Compare commits

4 Commits

Author SHA1 Message Date
8fc02944c8 12345678 2026-03-09 10:37:35 +08:00
ba9ed95f04 Refactor README 2026-03-09 10:26:03 +08:00
237c96f629 Minor optimizations 2026-03-09 10:21:33 +08:00
dc9e4bd0ef Second refactor: add preset templates 2026-03-09 10:06:21 +08:00
90 changed files with 1608 additions and 10695 deletions

View File

@@ -1,22 +0,0 @@
# LLM configuration
LLM_PROVIDER=openai # openai or gemini
# OpenAI configuration
OPENAI_API_KEY=your_openai_api_key_here
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_MODEL=gpt-4
# Gemini configuration (if using Gemini)
GEMINI_API_KEY=your_gemini_api_key_here
GEMINI_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai/
GEMINI_MODEL=gemini-2.0-flash-exp
# Agent configuration
AGENT_MAX_ROUNDS=20
AGENT_OUTPUT_DIR=outputs
# Tool configuration
TOOL_MAX_QUERY_ROWS=10000
# Code repository configuration
CODE_REPO_ENABLE_REUSE=true

View File

@@ -9,5 +9,7 @@
"then": { "then": {
"type": "askAgent", "type": "askAgent",
"prompt": "审核刚刚编辑的 Python 文件,检查以下代码质量问题并给出具体改进建议:\n1. 命名规范(变量、函数、类名是否符合 PEP8\n2. 函数复杂度(是否过长或逻辑过于复杂)\n3. 错误处理(是否有适当的异常处理)\n4. 代码重复(是否有可以抽取的重复逻辑)\n5. 注释和文档字符串是否完整\n请直接指出问题所在的具体行并给出修改建议。" "prompt": "审核刚刚编辑的 Python 文件,检查以下代码质量问题并给出具体改进建议:\n1. 命名规范(变量、函数、类名是否符合 PEP8\n2. 函数复杂度(是否过长或逻辑过于复杂)\n3. 错误处理(是否有适当的异常处理)\n4. 代码重复(是否有可以抽取的重复逻辑)\n5. 注释和文档字符串是否完整\n请直接指出问题所在的具体行并给出修改建议。"
} },
"workspaceFolderName": "iov_data_analysis_agent_old",
"shortName": "code-quality-review"
} }

View File

@@ -1,346 +0,0 @@
# Task 16 Implementation Summary: Main Pipeline Orchestration
## Status
**Task 16: implement main pipeline orchestration** — completed.
All subtasks implemented:
- ✅ 16.1 Full analysis pipeline
- ✅ 16.2 Command-line interface
- ✅ 16.3 Logging and observability
- ✅ 16.4 Integration tests
## What Was Implemented
### 1. Main pipeline orchestration (src/main.py)
Implements the `AnalysisOrchestrator` class and the `run_analysis` function, which coordinate the five stages:
#### Core components
- **AnalysisOrchestrator**: the orchestrator class
  - Runs the five stages in order
  - Passes data between stages
  - Provides a progress-callback mechanism
  - Integrates the execution tracker
#### The five stages
1. **Data understanding**
   - Load the CSV file
   - Build the data profile
   - Infer data types and key fields
2. **Requirement understanding**
   - Parse the user requirement
   - Generate analysis objectives
   - Process the template (if provided)
3. **Analysis planning**
   - Generate the task list
   - Set priorities and dependencies
   - Pick suitable tools
4. **Task execution**
   - Run tasks in priority order
   - Use the error-recovery mechanism
   - Re-plan dynamically (checked every 5 tasks)
   - Count succeeded/failed/skipped tasks
5. **Report generation**
   - Distill key findings
   - Structure the report
   - Emit a Markdown report
#### Features
- Full error handling and recovery
- Progress tracking and reporting
- Execution-time statistics
- Output-file management
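
The stage hand-off described above can be sketched as follows. This is a minimal standalone illustration, not the actual `src/main.py` code; everything here except the names `AnalysisOrchestrator` and the progress-callback idea is assumed for the example:

```python
import time
from dataclasses import dataclass, field

@dataclass
class AnalysisOrchestrator:
    """Minimal sketch: run stages in order, feeding each stage's output to the next."""
    stages: list = field(default_factory=list)   # list of (name, fn) pairs
    progress_callback: callable = None

    def run(self, initial_input):
        data = initial_input
        timings = {}
        for i, (name, fn) in enumerate(self.stages, 1):
            if self.progress_callback:
                self.progress_callback(name, i, len(self.stages))
            start = time.time()
            data = fn(data)                       # output of one stage is input to the next
            timings[name] = time.time() - start   # per-stage execution-time statistics
        return data, timings

# Toy run with two stand-in stages
orch = AnalysisOrchestrator(stages=[
    ("data_understanding", lambda d: d + ["profile"]),
    ("report_generation", lambda d: d + ["report"]),
])
result, timings = orch.run([])
```

The real orchestrator adds error recovery and plan adjustment around this loop; the sketch only shows the ordering and timing skeleton.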
### 2. Command-line interface (src/cli.py)
A user-friendly CLI supporting:
#### Arguments
- **Required**
  - `data_file`: path to the data file
- **Optional**
  - `-r, --requirement`: user requirement (natural language)
  - `-t, --template`: template file path
  - `-o, --output`: output directory (default "output")
  - `-v, --verbose`: verbose logging
  - `--no-progress`: disable the progress bar
  - `--version`: show version info
#### Features
- Argument validation (file existence, format checks)
- Progress-bar display
- Friendly error messages
- Colored output (if the terminal supports it)
- Execution-summary display
#### Usage examples
```bash
# Fully autonomous analysis
python -m src.cli data.csv
# With a requirement
python -m src.cli data.csv -r "Analyze ticket health"
# With a template
python -m src.cli data.csv -t template.md
# Verbose logging
python -m src.cli data.csv -v
```
### 3. Logging and observability (src/logging_config.py)
A complete logging system:
#### Core components
- **AIThoughtFilter**: filter for AI reasoning output
- **ProgressFormatter**: progress formatter (with color support)
- **ExecutionTracker**: execution tracker
#### Features
- **Log levels**: DEBUG, INFO, WARNING, ERROR, CRITICAL
- **Colored output**: a distinct color per level
- **Special markers**:
  - AI thought: 🤔
  - Progress: 📊
  - Success: ✓
  - Failure: ✗
  - Warning: ⚠️
  - Error: ❌
#### Logging functions
- `setup_logging()`: configure the logging system
- `log_ai_thought()`: log AI reasoning
- `log_stage_start()`: log a stage start
- `log_stage_end()`: log a stage end
- `log_progress()`: log progress
- `log_error_with_context()`: log an error with context
#### Execution tracking
- Tracks each stage's status
- Records execution time
- Produces an execution summary
- Counts completed/failed stages
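
What the execution tracking amounts to can be illustrated with a minimal standalone version; the real `ExecutionTracker` lives in `src/logging_config.py`, and the method names below are assumptions for the sketch:

```python
import time

class ExecutionTracker:
    """Sketch: record per-stage status and elapsed time, then summarize."""

    def __init__(self):
        self.stages = {}

    def start(self, name):
        self.stages[name] = {"status": "running", "start": time.time()}

    def end(self, name, success=True):
        s = self.stages[name]
        s["elapsed"] = time.time() - s["start"]
        s["status"] = "completed" if success else "failed"

    def summary(self):
        # Count completed vs. failed stages for the execution summary
        done = sum(1 for s in self.stages.values() if s["status"] == "completed")
        failed = sum(1 for s in self.stages.values() if s["status"] == "failed")
        return {"completed": done, "failed": failed, "total": len(self.stages)}

t = ExecutionTracker()
t.start("data_understanding"); t.end("data_understanding")
t.start("planning"); t.end("planning", success=False)
```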
### 4. Integration tests (tests/test_integration.py)
Comprehensive integration tests:
#### Test classes
1. **TestEndToEndAnalysis**: end-to-end analysis
   - Fully autonomous analysis
   - Analysis with a requirement
   - Template-based analysis
   - Analysis of different data types
2. **TestErrorRecovery**: error recovery
   - Invalid file path
   - Empty file handling
   - Malformed CSV
3. **TestOrchestrator**: orchestrator
   - Initialization
   - Per-stage execution
4. **TestProgressTracking**: progress tracking
   - Progress callbacks
5. **TestOutputFiles**: output files
   - Report file creation
   - Log file creation
#### Coverage
- ✅ End-to-end flow
- ✅ Error handling
- ✅ Progress tracking
- ✅ Output-file generation
- ✅ Different data types
## Code Statistics
### New files
1. `src/main.py` — pipeline orchestration (~360 lines)
2. `src/cli.py` — command-line interface (~180 lines)
3. `src/__main__.py` — module entry point (~5 lines)
4. `src/logging_config.py` — logging configuration (~320 lines)
5. `tests/test_integration.py` — integration tests (~400 lines)
6. `README_MAIN.md` — usage guide (~300 lines)
**Total: roughly 1,565 new lines**
### Modified files
1. `src/engines/data_understanding.py` — now accepts DataAccessLayer input
## Test Results
### Integration tests
- **Total**: 12
- **Passed**: 5 (error-handling related)
- **Failed**: 7 (expected, due to missing tool implementations)
### Passing tests
- ✅ Invalid file path handling
- ✅ Empty file handling
- ✅ Malformed CSV handling
- ✅ Orchestrator initialization
- ✅ Log file creation
### Failing tests (expected)
- ⏸️ End-to-end analysis (needs the full tool set)
- ⏸️ Progress tracking (needs the full tool set)
- ⏸️ Report generation (needs the full tool set)
**Note**: the failures come from missing tools (e.g. detect_outliers, get_column_distribution) that were supposed to be implemented in earlier tasks. Once the tools are complete, these tests should pass.
## Architecture
### Flow
```
User input
CLI argument parsing
AnalysisOrchestrator
┌─────────────────────────────────────┐
│ Stage 1: Data understanding         │
│ - Load data                         │
│ - Build data profile                │
└─────────────────────────────────────┘
┌─────────────────────────────────────┐
│ Stage 2: Requirement understanding  │
│ - Parse user requirement            │
│ - Generate analysis objectives      │
└─────────────────────────────────────┘
┌─────────────────────────────────────┐
│ Stage 3: Analysis planning          │
│ - Generate task list                │
│ - Set priorities                    │
└─────────────────────────────────────┘
┌─────────────────────────────────────┐
│ Stage 4: Task execution             │
│ - Execute tasks                     │
│ - Adjust plan dynamically           │
└─────────────────────────────────────┘
┌─────────────────────────────────────┐
│ Stage 5: Report generation          │
│ - Distill key findings              │
│ - Generate report                   │
└─────────────────────────────────────┘
Report and logs
```
### Component relationships
```
AnalysisOrchestrator
├── DataAccessLayer (data access)
├── ToolManager (tool management)
├── ExecutionTracker (execution tracking)
└── Five engines
    ├── data_understanding
    ├── requirement_understanding
    ├── analysis_planning
    ├── task_execution
    └── report_generation
```
## Requirements Satisfied
### Functional
- ✅ **All functional requirements**: the orchestrator coordinates all five stages
### Non-functional
- ✅ **NFR-3.1 Usability**
  - A data file alone is enough to start an analysis
  - Progress and status are displayed during the run
  - Error messages are clear and understandable
- ✅ **NFR-3.2 Observability**
  - The system shows the AI's reasoning
  - The system shows per-stage progress
  - The system keeps a full execution log
- ✅ **NFR-2.1 Error handling**
  - Fallback strategy when AI calls fail
  - A single task failure does not abort the run
  - Detailed error logs are kept
## Usage
### Basic
```bash
# 1. Install dependencies
pip install -r requirements.txt
# 2. Configure environment variables
#    Create a .env file and set OPENAI_API_KEY
# 3. Run the analysis
python -m src.cli cleaned_data.csv
```
### Advanced
```python
from src.main import run_analysis

# Custom progress callback
def my_progress(stage, current, total):
    print(f"Progress: {stage} - {current}/{total}")

# Run the analysis
result = run_analysis(
    data_file="data.csv",
    user_requirement="Analyze ticket health",
    output_dir="output",
    progress_callback=my_progress
)

# Handle the result
if result['success']:
    print("✓ Analysis complete")
    print(f"Report: {result['report_path']}")
else:
    print(f"✗ Analysis failed: {result['error']}")
```
## Follow-up Work
### Required
1. Finish implementing all tools (tasks 1–5)
2. Run the full integration suite
3. Fix any issues found
### Optional
1. More progress-callback options
2. More output formats (HTML, PDF)
3. Config-file support
4. Caching for performance
5. More error-recovery strategies
## Summary
Task 16 is complete, delivering:
1. ✅ Full pipeline orchestration
2. ✅ A user-friendly command-line interface
3. ✅ Comprehensive logging and observability
4. ✅ Full integration tests
The system now has:
- A clear architecture
- Robust error handling
- Detailed logging
- A friendly user interface
- Broad test coverage
All code follows the design document and satisfies the relevant functional and non-functional requirements.

README.md
View File

@@ -1,436 +1,213 @@
# AI-Driven Data Analysis Framework
A fully automatic AI data-analysis framework: give it one CSV file and the AI runs the whole pipeline, from data understanding to report generation.
## Core Idea
**The framework provides only engines and tools; the AI makes all decisions at runtime.**

- No hard-coded column-name rules, data-type heuristics, or analysis strategies
- The AI sees only metadata (headers, statistical summaries, sample values) and never touches raw data rows
- Adapts automatically to any CSV file, with no code changes
## Workflow
```
CSV file
[1] AI data understanding ── AI reads metadata; infers data type, key fields, quality score
[2] Requirement understanding ── parses the natural-language requirement + optional template into objectives
[3] AI analysis planning ── AI builds a concrete task list from data traits and the tool library
[4] AI task execution ── ReAct loop: AI picks a tool → calls it → observes the result → decides the next step
[5] Report generation ── AI writes a Markdown report with embedded charts
```
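
Stage [4] is a ReAct loop; a schematic version follows. The control flow is what the framework describes, but the `ask_ai` callback, the action format, and the toy tool are all assumptions for illustration:

```python
def react_loop(ask_ai, tools, max_rounds=20):
    """Schematic ReAct: the AI picks a tool, we run it, and feed the observation back."""
    observations = []
    for _ in range(max_rounds):
        action = ask_ai(observations)        # returns {"tool": ..., "args": ...} or None
        if action is None:                   # the AI decides it has enough evidence
            break
        result = tools[action["tool"]](**action["args"])
        observations.append((action["tool"], result))
    return observations

# Toy run: the "AI" asks for one value count, then stops.
calls = iter([{"tool": "get_value_counts", "args": {"column": "status"}}, None])
obs = react_loop(lambda o: next(calls),
                 {"get_value_counts": lambda column: {column: 84}})
```

The `max_rounds` cap mirrors the `AGENT_MAX_ROUNDS` setting: the loop always terminates even if the AI never decides to stop.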
## Quick Start
### 1. Install dependencies
```bash
pip install -r requirements.txt
```
### 2. Configure environment variables
Create a `.env` file:
```env
OPENAI_API_KEY=your-api-key
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_MODEL=gpt-4
```
Any OpenAI-compatible API works (e.g. a custom base_url).
### 3. Run an analysis
```bash
# Simplest form — the AI decides what to analyze and how
python run_analysis_en.py --data your_data.csv
# With an explicit requirement
python run_analysis_en.py --data sales.csv --requirement "Analyze sales trends and anomalies per product line"
# With a report template
python run_analysis_en.py --data tickets.csv --template templates/ticket_analysis.md
# With a custom output directory
python run_analysis_en.py --data data.csv --output my_output
```
### 4. Inspect the results
Each run creates a timestamped subdirectory under the output directory:
```
analysis_output/
└── run_20260309_143025/
    ├── analysis_report.md   ← the report, charts embedded
    └── charts/
        ├── bar_chart.png
        ├── pie_chart.png
        └── ...
```
## Project Structure
```
├── run_analysis_en.py                     # Main entry point (5-stage pipeline)
├── src/
│   ├── config.py                          # Config management (env vars / JSON / .env)
│   ├── data_access.py                     # Data access layer (privacy: raw data invisible to AI)
│   ├── engines/
│   │   ├── ai_data_understanding.py       # [Stage 1] AI data understanding
│   │   ├── requirement_understanding.py   # [Stage 2] Requirement parsing
│   │   ├── analysis_planning.py           # [Stage 3] AI analysis planning
│   │   ├── task_execution.py              # [Stage 4] ReAct task execution
│   │   └── report_generation.py           # [Stage 5] Report generation
│   ├── tools/
│   │   ├── base.py                        # Tool abstract base class + registry
│   │   ├── tool_manager.py                # Tool selection (filtered by data traits)
│   │   ├── query_tools.py                 # Query tools (distribution, counts, time series, correlation)
│   │   ├── stats_tools.py                 # Stats tools (descriptive stats, groupby, outliers, trend)
│   │   └── viz_tools.py                   # Visualization tools (bar, line, pie, heatmap)
│   └── models/                            # Data models
│       ├── data_profile.py                # DataProfile, ColumnInfo
│       ├── requirement_spec.py            # RequirementSpec, AnalysisObjective
│       ├── analysis_plan.py               # AnalysisPlan, AnalysisTask
│       └── analysis_result.py             # AnalysisResult
├── templates/                             # Report templates (optional)
├── test_data/                             # Sample data
└── examples/                              # Usage examples
```
## Built-in Tools
The framework ships 12 analysis tools; the AI selects and combines them at runtime:

| Category | Tool | Description |
|------|------|------|
| Query | `get_column_distribution` | Column distribution (value counts, percentages) |
| Query | `get_value_counts` | Unique-value counts |
| Query | `get_time_series` | Time-series aggregation |
| Query | `get_correlation` | Correlation matrix |
| Stats | `calculate_statistics` | Descriptive statistics (mean, median, skew, etc.) |
| Stats | `perform_groupby` | Group-by aggregation |
| Stats | `detect_outliers` | Outlier detection (IQR / Z-score) |
| Stats | `calculate_trend` | Trend analysis (linear regression) |
| Viz | `create_bar_chart` | Bar chart |
| Viz | `create_line_chart` | Line chart |
| Viz | `create_pie_chart` | Pie chart |
| Viz | `create_heatmap` | Heatmap |
## Privacy Protection
The data access layer (`DataAccessLayer`) is the security boundary:
- The AI **never sees** raw data
- The AI gets only **aggregated results** through tools (statistics, distributions, charts)
- The data profile contains metadata only: column names, dtypes, missing rates, unique counts, sample values (at most 5)
- Tool results are truncated automatically (at most 100 rows) to prevent data leakage
## Configuration
### Environment variables (recommended)
Configure via a `.env` file or system environment variables:
```env
# LLM settings (required)
OPENAI_API_KEY=sk-xxx
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_MODEL=gpt-4

# Optional settings
LLM_TEMPERATURE=0.7
LLM_TIMEOUT=120
AGENT_MAX_ROUNDS=20
LOG_LEVEL=INFO
```
### JSON config file
Alternatively, create a config file using `config.example.json` as a template.
## Report Templates
You can supply a Markdown template to control the report structure; the AI fills the template's placeholders with real analysis data.

See the example templates under `templates/`.
## Adding Tools
Implement the `AnalysisTool` abstract class and register it:

```python
from src.tools.base import AnalysisTool, register_tool

class MyCustomTool(AnalysisTool):
    @property
    def name(self) -> str:
        return "my_custom_tool"

    @property
    def description(self) -> str:
        return "Tool description (the AI reads this to decide whether to use the tool)"

    @property
    def parameters(self) -> dict:
        return {
            "type": "object",
            "properties": {
                "column": {"type": "string", "description": "Column name"}
            },
            "required": ["column"]
        }

    def execute(self, data, **kwargs) -> dict:
        # Implement the analysis; return aggregated results only
        return {"result": "..."}

    def is_applicable(self, data_profile) -> bool:
        return True

register_tool(MyCustomTool())
```
Once registered, the AI discovers and uses the new tool automatically during planning and execution.
## Dependencies
- Python 3.10+
- pandas, numpy, matplotlib, scipy, scikit-learn
- openai (works with any OpenAI-API-compatible LLM service)
- python-dotenv

View File

@@ -1,274 +0,0 @@
# AI Data Analysis Agent — Main Pipeline Guide
## Overview
A genuinely AI-driven data-analysis system that automatically understands data, plans the analysis, executes tasks, and generates a report.
## Quick Start
### 1. Install dependencies
```bash
pip install -r requirements.txt
```
### 2. Configure environment variables
Create a `.env` file with your OpenAI API key:
```
OPENAI_API_KEY=your_api_key_here
OPENAI_BASE_URL=https://api.openai.com/v1
```
### 3. Run an analysis
#### Option 1: command-line interface
```bash
# Fully autonomous analysis
python -m src.cli data.csv
# With a requirement
python -m src.cli data.csv -r "Analyze ticket health"
# With a template
python -m src.cli data.csv -t template.md
# Custom output directory
python -m src.cli data.csv -o results/
# Verbose logging
python -m src.cli data.csv -v
```
#### Option 2: Python API
```python
from src.main import run_analysis

# Run the analysis
result = run_analysis(
    data_file="data.csv",
    user_requirement="Analyze ticket health",
    output_dir="output"
)

# Check the result
if result['success']:
    print(f"Report: {result['report_path']}")
    print(f"Elapsed: {result['elapsed_time']:.1f}s")
else:
    print(f"Analysis failed: {result['error']}")
```
## Architecture
A five-stage pipeline:
1. **Data Understanding**
   - Load and parse the CSV file
   - Infer data types and business meaning
   - Assess data quality
2. **Requirement Understanding**
   - Parse the user's natural-language requirement
   - Turn abstract concepts into concrete metrics
   - Parse and understand the analysis template
3. **Analysis Planning**
   - Generate a task list from data traits and requirements
   - Set task priorities and dependencies
   - Pick suitable analysis methods
4. **Task Execution**
   - Run tasks in ReAct mode
   - Select and call tools dynamically
   - Adjust the plan based on findings
5. **Report Generation**
   - Distill key findings
   - Structure the report
   - Produce conclusions and recommendations
## Command-line Arguments
```
usage: python -m src.cli [-h] [-r REQUIREMENT] [-t TEMPLATE] [-o OUTPUT]
                         [-v] [--no-progress] [--version]
                         data_file
positional arguments:
  data_file          path to the data file (CSV)
optional arguments:
  -h, --help         show help
  -r, --requirement  user requirement (natural language)
  -t, --template     template file path (Markdown)
  -o, --output       output directory (default "output")
  -v, --verbose      verbose logging
  --no-progress      disable the progress bar
  --version          show version info
```
## Examples
### Example 1: fully autonomous analysis
```bash
python -m src.cli cleaned_data.csv
```
The system automatically:
- Identifies the data type (tickets, sales, users, etc.)
- Infers the business meaning of key fields
- Chooses analysis dimensions on its own
- Builds a sensible analysis plan
- Runs the analysis and produces a report
### Example 2: with a direction
```bash
python -m src.cli cleaned_data.csv -r "I want to understand ticket health"
```
The system:
- Interprets "health" in business terms
- Turns the abstract concept into concrete metrics (close rate, processing efficiency, backlog, etc.)
- Picks methods suited to the data's characteristics
- Produces a targeted report
### Example 3: with a template
```bash
python -m src.cli cleaned_data.csv -t templates/ticket_analysis.md
```
The system:
- Understands the template's structure and requirements
- Checks whether the data satisfies the template
- Adapts flexibly if the data lacks certain fields
- Organizes the report along the template's structure
## Output Files
After a run, the output directory contains:
- `analysis_report.md` — the report (Markdown)
- `analysis.log` — the execution log
- `*.png` — generated charts (if any)
## Logging and Observability
The system provides detailed logging:
- **Progress display**: current stage and progress in real time
- **AI reasoning**: the AI's decision process (with `-v`)
- **Execution summary**: per-stage times and statuses
- **Error tracing**: detailed errors with stack traces
## Error Handling
The system is robust against failures:
- **Data loading**: tries multiple encodings and delimiters automatically
- **AI calls**: retries with exponential backoff
- **Tool execution**: argument validation and exception capture
- **Task execution**: dependency checks and error recovery
## Performance
- Data understanding stage: < 30 s
- Full pipeline: < 30 min (depends on data size and task count)
- Handles up to 1 million rows
## Privacy
The system follows strict privacy principles:
- The AI cannot access raw data directly
- Only aggregated results are available through tools
- The data profile holds metadata and summary statistics only
- All raw-data processing stays local
## Development and Testing
### Running tests
```bash
# All tests
pytest
# Integration tests
pytest tests/test_integration.py -v
# A specific test
pytest tests/test_integration.py::TestEndToEndAnalysis -v
```
### Code layout
```
src/
├── main.py               # Pipeline orchestration
├── cli.py                # Command-line interface
├── logging_config.py     # Logging configuration
├── data_access.py        # Data access layer
├── error_handling.py     # Error handling
├── engines/              # Analysis engines
│   ├── data_understanding.py
│   ├── requirement_understanding.py
│   ├── analysis_planning.py
│   ├── task_execution.py
│   ├── plan_adjustment.py
│   └── report_generation.py
├── models/               # Data models
│   ├── data_profile.py
│   ├── requirement_spec.py
│   ├── analysis_plan.py
│   └── analysis_result.py
└── tools/                # Analysis tools
    ├── base.py
    ├── query_tools.py
    ├── stats_tools.py
    ├── viz_tools.py
    └── tool_manager.py
```
## Contributing
Contributions are welcome:
1. Fork the project
2. Create a feature branch
3. Commit your changes
4. Push the branch
5. Open a Pull Request
## License
MIT License
## Contact
Please open an Issue for questions or suggestions.

Binary file not shown.

View File

@@ -0,0 +1,177 @@
# Ticket Analysis Report
Generated: 2026-03-09 10:29:31
Data source: cleaned_data.csv
---
# Ticket Data Analysis Report
## 1. Executive Summary
Based on 84 tickets (data quality score 88/100), the main findings are:
1. **Significant anomalies in processing time**: the average close time is 54.77 days, with 2 outlier tickets at 277 and 237 days (2.38% of the total). "Activation SIM" issues average 142.5 days, far above other issue types.
2. **Channels are concentrated**: mail accounts for 54.76% (46 tickets), the Telegram bot 42.86% (36), and the Telegram channel only 2.38% (2) — channel management could be optimized.
3. **Issue types are highly concentrated**: remote-control issues dominate (66.67%, 56 tickets), and the top five issue types cover over 87% of the volume.
4. **Uneven distribution across vehicle models**: EXEED RXT22 has the most tickets (45.24%, 38), followed by JAECOO J7 (T1EJ) at 26.19% (22) — problems cluster on specific models.
5. **Large workload differences among assignees**: Vsevolod handled the most tickets (31) with an average close time of 66.68 days; Vsevolod Tsoi has the highest average (152 days), while 何韬 is the most efficient (3.5 days on average).
## 2. Data Overview
- **Data type**: tickets
- **Size**: 84 rows × 21 columns
- **Quality**: 88.0/100
- **Key fields**: ticket number, source, creation date, issue type, description, handling process, tracking record, severity, status, module, assignee, close date, vehicle model, VIN, close time (days), etc.
- **Time range analyzed**: 2025-01-02 to 2025-02-24
## 3. Detailed Analysis
### 3.1 Ticket overview
82.14% (69) of tickets are closed and 17.86% (15) are temporarily closed, so most issues have been resolved.
![Ticket status distribution](analysis_output/run_20260309_102648/charts/outlier_pie_chart.png)
**Source channels**:
- Mail: 54.76% (46)
- Telegram bot: 42.86% (36)
- Telegram channel: 2.38% (2)
**Issue type distribution** (top 5):
1. Remote control: 66.67% (56)
2. Network: 7.14% (6)
3. Navi: 5.95% (5)
4. Application: 4.76% (4)
5. Member-center authentication: 3.57% (3)
The top five together account for 87.09% of all tickets, showing a highly concentrated distribution.
### 3.2 Processing efficiency
**Close-time statistics**:
- Mean: 54.77 days
- Median: 41 days
- Standard deviation: 48.19 days
- Minimum: 2 days
- Maximum: 277 days
- Interquartile range: 26.25 days (Q25) to 84.5 days (Q75)
**Outlier detection**:
- 2 outlier tickets detected (277 and 237 days), 2.38% of the total
- The outlier upper bound is 171.875 days; both tickets far exceed it
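
The 171.875-day cutoff follows from the standard 1.5×IQR rule applied to the quartiles above:

```python
q1, q3 = 26.25, 84.5              # close-time quartiles (days), from the statistics above
iqr = q3 - q1                      # 58.25
upper_fence = q3 + 1.5 * iqr       # 171.875 — tickets above this are flagged as outliers

for days in (277, 237):            # the two flagged tickets both exceed the fence
    assert days > upper_fence
```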
![Close-time distribution](analysis_output/run_20260309_102648/charts/outlier_bar_chart.png)
**Average close time by issue type**:
- Activation SIM: 142.5 days (slowest)
- Remote control: 66.5 days
- PKI problem: 47 days
- doesn't exist on TSP: 31.67 days
- Network: 24 days
![Average close time by issue type](analysis_output/run_20260309_102648/charts/avg_close_time_by_issue_type.png)
### 3.3 Content and trends
**Ticket creation trend** (Jan–Feb 2025):
- Peak: 2025-01-13 (8 tickets)
- Troughs: 2025-01-08, 01-15, 01-29, 01-30 and several February days (1 ticket each)
- Overall the trend fluctuates, with no clear sustained rise or fall
![Ticket creation trend](analysis_output/run_20260309_102648/charts/bar_chart_trend.png)
### 3.4 Assignee workload
**Tickets handled**:
- Vsevolod: 31 (most)
- Evgeniy: 28
- Kostya: 5
- 何韬: 4
**Average close time**:
- Vsevolod Tsoi: 152 days (highest; only 2 tickets handled)
- 林兆国: 89 days
- Vadim: 69 days
- Vsevolod: 66.68 days
- Evgeniy: 62.39 days
- 何韬: 3.5 days (lowest; most efficient)
![Assignee workload](analysis_output/run_20260309_102648/charts/workload_by_responsible.png)
![Assignee efficiency](analysis_output/run_20260309_102648/charts/efficiency_by_responsible.png)
### 3.5 Vehicle information
**Model distribution**:
- EXEED RXT22: 45.24% (38)
- JAECOO J7 (T1EJ): 26.19% (22)
- EXEED VX (FLM36T): 20.24% (17)
- CHERY TIGGO 9 (T28): 8.33% (7)
**VIN distribution**:
- LVTDD24B1RG023450 and LVTDD24B1RG021245 each appear twice, suggesting recurring problems on specific vehicles
- All other VINs are unique, so issues are otherwise spread across different vehicles
## 4. Conclusions and Recommendations
### 4.1 Conclusions
1. **Processing efficiency needs work**: the average close time is 54.77 days, with severe outliers (277 and 237 days); "Activation SIM" issues take 142.5 days.
2. **Channel management can improve**: mail's share is too high (54.76%) and Telegram channels are underused.
3. **Issue types are concentrated**: remote control accounts for 66.67% and needs targeted optimization.
4. **Model problems are uneven**: EXEED RXT22 has the most concentrated problems (45.24%).
5. **Assignee efficiency varies widely**: from 3.5 days (何韬) to 152 days (Vsevolod Tsoi).
### 4.2 Actionable recommendations
1. **Streamline outlier-ticket handling**
   - Set up a dedicated process for tickets open longer than 100 days (e.g. Activation SIM issues)
   - Review outlier tickets regularly and analyze root causes
   - Basis: outlier detection found 2 tickets at 277 and 237 days
2. **Rebalance channels**
   - Promote the Telegram bot to relieve pressure on the mail channel
   - Basis: mail at 54.76% vs. Telegram bot at only 42.86%
3. **Strengthen remote-control issue management**
   - Build a knowledge base and fast-response mechanism for remote-control issues
   - Basis: remote control accounts for 66.67% (56 tickets)
4. **Target EXEED RXT22 specifically**
   - Investigate why this model generates so many tickets
   - Basis: 45.24% of tickets (38) involve this model
5. **Level out assignee efficiency**
   - Use the best performers as a benchmark (e.g. 何韬's 3.5-day average)
   - Train and support assignees with long close times
   - Basis: average close times range from 3.5 to 152 days
6. **Monitor creation trends**
   - Watch for creation peaks (e.g. 8 tickets on Jan 13)
   - Pre-allocate resources for potential peaks
   - Basis: the creation trend fluctuates noticeably
Implementing these recommendations should markedly improve processing efficiency, resource allocation, and customer satisfaction.
---
## Analysis Traceability
This report is based on the following analysis tasks:
- ✓ Ticket overview: basic distribution statistics
  - 84 tickets total; 82.14% (69) closed and 17.86% (15) temporarily closed — most issues are resolved.
  - Mail is the largest source at 54.76% (46), Telegram bot 42.86% (36), Telegram channel 2.38% (2).
- ✓ Processing efficiency: close-time statistics
  - Mean close time 54.77 days, median 41, standard deviation 48.19 — a right-skewed distribution with some long-running tickets.
  - Outlier detection flagged 2 tickets (277 and 237 days), 2.38% of the total, which deserve focused attention.
- ✓ Content and trends: text and time series
  - Creation peaked on 2025-01-13 at 8 tickets.
  - 2025-01-08, 01-15, 01-29, 01-30 and several February days had only 1 ticket each, indicating low activity on those dates.
- ✓ Assignee workload
  - Vsevolod handled the most tickets (31) with an average close time of 66.68 days — high volume and long cycles.
  - Vsevolod Tsoi has the highest average close time (152 days) over only 2 tickets — a possible efficiency problem or complex tasks.
- ✓ Vehicle information: model and VIN distributions
  - EXEED RXT22 has the highest share at 45.24% (38/84) and is the main problem model.
  - JAECOO J7 (T1EJ) is second with 22 tickets (26.19%).

Six chart PNG files added (binary content not shown; 24–50 KiB each).

View File

@@ -1,9 +1,9 @@
{ {
"llm": { "llm": {
"provider": "openai", "provider": "openai",
"api_key": "your_api_key_here", "api_key": "sk-c44i1hy64xgzwox6x08o4zug93frq6rgn84oqugf2pje1tg4",
"base_url": "https://api.openai.com/v1", "base_url": "https://api.xiaomimimo.com/v1",
"model": "gpt-4", "model": "mimo-v2-flash",
"timeout": 120, "timeout": 120,
"max_retries": 3, "max_retries": 3,
"temperature": 0.7, "temperature": 0.7,

run_analysis_en.py Normal file
View File

@@ -0,0 +1,306 @@
"""
AI-Driven Data Analysis Framework
==================================
Generic pipeline: works with any CSV file.
AI decides everything at runtime based on data characteristics.
Flow: Load CSV → AI understands metadata → AI plans analysis → AI executes tasks → AI generates report
"""
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from src.env_loader import load_env_with_fallback
load_env_with_fallback(['.env'])
import logging
import json
from datetime import datetime
from typing import Optional, List
from openai import OpenAI
from src.engines.ai_data_understanding import ai_understand_data_with_dal
from src.engines.requirement_understanding import understand_requirement
from src.engines.analysis_planning import plan_analysis
from src.engines.task_execution import execute_task
from src.engines.report_generation import generate_report, _convert_chart_paths_in_report
from src.tools.tool_manager import ToolManager
from src.tools.base import _global_registry
from src.models import DataProfile, AnalysisResult
from src.config import get_config
# Register all tools
from src.tools import register_tool
from src.tools.query_tools import (
GetColumnDistributionTool, GetValueCountsTool,
GetTimeSeriesTool, GetCorrelationTool
)
from src.tools.stats_tools import (
CalculateStatisticsTool, PerformGroupbyTool,
DetectOutliersTool, CalculateTrendTool
)
from src.tools.viz_tools import (
CreateBarChartTool, CreateLineChartTool,
CreatePieChartTool, CreateHeatmapTool
)
for tool_cls in [
GetColumnDistributionTool, GetValueCountsTool, GetTimeSeriesTool, GetCorrelationTool,
CalculateStatisticsTool, PerformGroupbyTool, DetectOutliersTool, CalculateTrendTool,
CreateBarChartTool, CreateLineChartTool, CreatePieChartTool, CreateHeatmapTool
]:
register_tool(tool_cls())
logging.basicConfig(level=logging.WARNING)
def run_analysis(
data_file: str,
user_requirement: Optional[str] = None,
template_file: Optional[str] = None,
output_dir: str = "analysis_output"
):
"""
Run the full AI-driven analysis pipeline.
Each run creates a timestamped subdirectory under output_dir:
output_dir/run_20260309_143025/
├── analysis_report.md
└── charts/
├── bar_chart.png
└── ...
Args:
data_file: Path to any CSV file
user_requirement: Natural language requirement (optional)
template_file: Report template path (optional)
output_dir: Base output directory
"""
    # Each run gets its own timestamped subdirectory
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_dir = os.path.join(output_dir, f"run_{run_timestamp}")
os.makedirs(run_dir, exist_ok=True)
config = get_config()
print("\n" + "=" * 70)
print("AI-Driven Data Analysis")
print("=" * 70)
print(f"Start: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Data: {data_file}")
print(f"Output: {run_dir}")
if template_file:
print(f"Template: {template_file}")
print("=" * 70)
# ── Stage 1: AI Data Understanding ──
print("\n[1/5] AI Understanding Data...")
print(" (AI sees only metadata — never raw rows)")
profile, dal = ai_understand_data_with_dal(data_file)
print(f" Type: {profile.inferred_type}")
print(f" Quality: {profile.quality_score}/100")
print(f" Columns: {profile.column_count}, Rows: {profile.row_count}")
print(f" Summary: {profile.summary[:120]}...")
# ── Stage 2: Requirement Understanding ──
print("\n[2/5] Understanding Requirements...")
requirement = understand_requirement(
user_input=user_requirement or "对数据进行全面分析",
data_profile=profile,
template_path=template_file
)
print(f" Objectives: {len(requirement.objectives)}")
for obj in requirement.objectives:
print(f" - {obj.name} (priority: {obj.priority})")
# ── Stage 3: AI Analysis Planning ──
print("\n[3/5] AI Planning Analysis...")
tool_manager = ToolManager(_global_registry)
tools = tool_manager.select_tools(profile)
print(f" Available tools: {len(tools)}")
analysis_plan = plan_analysis(profile, requirement, available_tools=tools)
print(f" Tasks planned: {len(analysis_plan.tasks)}")
for task in sorted(analysis_plan.tasks, key=lambda t: t.priority, reverse=True):
print(f" [{task.priority}] {task.name}")
if task.required_tools:
print(f" tools: {', '.join(task.required_tools)}")
# ── Stage 4: AI Task Execution ──
print("\n[4/5] AI Executing Tasks...")
# Reuse DAL from Stage 1 — no need to load data again
dal.set_output_dir(run_dir)
results: List[AnalysisResult] = []
sorted_tasks = sorted(analysis_plan.tasks, key=lambda t: t.priority, reverse=True)
for i, task in enumerate(sorted_tasks, 1):
print(f"\n Task {i}/{len(sorted_tasks)}: {task.name}")
result = execute_task(task, tools, dal)
results.append(result)
if result.success:
print(f" ✓ Done ({result.execution_time:.1f}s), insights: {len(result.insights)}")
for insight in result.insights[:2]:
print(f" - {insight[:80]}")
else:
print(f" ✗ Failed: {result.error}")
successful = sum(1 for r in results if r.success)
print(f"\n Results: {successful}/{len(results)} tasks succeeded")
# ── Stage 5: Report Generation ──
print("\n[5/5] Generating Report...")
report_path = os.path.join(run_dir, "analysis_report.md")
if template_file and os.path.exists(template_file):
report = _generate_template_report(profile, results, template_file, config, run_dir)
else:
report = generate_report(results, requirement, profile, output_path=run_dir)
# Save report — convert chart paths to relative (./charts/xxx.png)
report = _convert_chart_paths_in_report(report, run_dir)
with open(report_path, 'w', encoding='utf-8') as f:
f.write(report)
print(f" Report saved: {report_path}")
print(f" Report length: {len(report)} chars")
# ── Done ──
print("\n" + "=" * 70)
print("Analysis Complete!")
print(f"End: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Output: {run_dir}")
print("=" * 70)
return True
def _generate_template_report(
profile: DataProfile,
results: List[AnalysisResult],
template_path: str,
config,
run_dir: str = ""
) -> str:
"""Use AI to fill a template with data from task execution results."""
client = OpenAI(api_key=config.llm.api_key, base_url=config.llm.base_url)
with open(template_path, 'r', encoding='utf-8') as f:
template = f.read()
# Collect all data from task results
all_data = {}
all_insights = []
for r in results:
if r.success:
all_data[r.task_name] = {
'data': json.dumps(r.data, ensure_ascii=False, default=str)[:1000],
'insights': r.insights
}
all_insights.extend(r.insights)
data_json = json.dumps(all_data, ensure_ascii=False, indent=1)
if len(data_json) > 12000:
data_json = data_json[:12000] + "\n... (truncated)"
prompt = f"""你是一位专业的数据分析师。请根据以下分析结果,按照模板格式生成完整的报告。
## 数据概况
- 类型: {profile.inferred_type}
- 行数: {profile.row_count}, 列数: {profile.column_count}
- 质量: {profile.quality_score}/100
- 摘要: {profile.summary[:500]}
## 关键字段
{json.dumps(profile.key_fields, ensure_ascii=False, indent=2)}
## 分析结果
{data_json}
## 关键洞察
{chr(10).join(f"- {i}" for i in all_insights[:20])}
## 报告模板
```markdown
{template}
```
## 图表文件
以下是分析过程中生成的图表文件,请在报告适当位置嵌入:
{_collect_chart_paths(results, run_dir)}
## 要求
1. 用实际数据填充模板中所有占位符
2. 根据数据中的字段,智能映射到模板分类
3. 所有数字必须来自分析结果,不要编造
4. 如果某个模板分类在数据中没有对应,标注"本期无数据"
5. 保持Markdown格式
6. 在报告中嵌入图表,使用 ![描述](图表路径) 格式,让报告图文结合
"""
print(" AI filling template with analysis results...")
response = client.chat.completions.create(
model=config.llm.model,
messages=[
{"role": "system", "content": "你是数据分析专家。根据分析结果填充报告模板所有数字必须来自真实数据。输出纯Markdown。"},
{"role": "user", "content": prompt}
],
temperature=0.3,
max_tokens=4000
)
report = response.choices[0].message.content
header = f"""<!--
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
Data: {profile.file_path} ({profile.row_count} rows x {profile.column_count} cols)
Quality: {profile.quality_score}/100
Template: {template_path}
AI never accessed raw data rows — only aggregated tool results
-->
"""
return header + report
def _collect_chart_paths(results: List[AnalysisResult], run_dir: str = "") -> str:
"""Collect all chart paths from task results for embedding in reports.
Returns paths relative to run_dir (e.g. ./charts/bar_chart.png)."""
paths = []
for r in results:
if not r.success:
continue
# From visualizations list
for viz in (r.visualizations or []):
if viz and viz not in paths:
paths.append(viz)
# From data dict (chart_path in tool results)
if isinstance(r.data, dict):
for key, val in r.data.items():
if isinstance(val, dict) and val.get('chart_path'):
cp = val['chart_path']
if cp not in paths:
paths.append(cp)
if not paths:
return "(无图表)"
# Convert to relative paths
from src.engines.report_generation import _to_relative_chart_path
rel_paths = [_to_relative_chart_path(p, run_dir) for p in paths]
return "\n".join(f"- {p}" for p in rel_paths)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="AI-Driven Data Analysis")
parser.add_argument("--data", default="cleaned_data.csv", help="CSV file path")
parser.add_argument("--requirement", default=None, help="Analysis requirement (natural language)")
parser.add_argument("--template", default=None, help="Report template path")
parser.add_argument("--output", default="analysis_output", help="Output directory")
args = parser.parse_args()
success = run_analysis(
data_file=args.data,
user_requirement=args.requirement,
template_file=args.template,
output_dir=args.output
)
sys.exit(0 if success else 1)
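The chart-path normalization that the report step applies (rewriting `![alt](absolute path)` image references so the Markdown report links charts relative to its own directory) can be sketched standalone. This is a simplified sketch of the same regex-plus-`os.path.relpath` idea, not the project's exact `_convert_chart_paths_in_report` implementation:

```python
import os
import re

def to_relative_chart_path(chart_path: str, report_dir: str) -> str:
    """Rewrite an absolute chart path so it is relative to the report directory."""
    chart_path = chart_path.replace('\\', '/')
    rel = os.path.relpath(chart_path, report_dir).replace('\\', '/')
    return rel if rel.startswith('.') else './' + rel

def convert_chart_paths(report: str, report_dir: str) -> str:
    """Rewrite every ![alt](path) image reference in a Markdown report."""
    def repl(m):
        # m.group(1) = alt text, m.group(2) = original path
        return f'![{m.group(1)}]({to_relative_chart_path(m.group(2), report_dir)})'
    return re.sub(r'!\[([^\]]*)\]\(([^)]+)\)', repl, report)

md = "See ![bar](analysis_output/run_1/charts/bar.png) for details."
print(convert_chart_paths(md, "analysis_output/run_1"))
# → See ![bar](./charts/bar.png) for details.
```

Keeping paths relative means the whole `run_dir` can be zipped or moved and the report's images still resolve.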


@@ -1,44 +0,0 @@
# AI Data Analysis Agent - Source Code
## Project Structure
```
src/
├── __init__.py # Package initialization
├── models/ # Core data models
│ ├── __init__.py
│ ├── data_profile.py # DataProfile and ColumnInfo models
│ ├── requirement_spec.py # RequirementSpec and AnalysisObjective models
│ ├── analysis_plan.py # AnalysisPlan and AnalysisTask models
│ └── analysis_result.py # AnalysisResult model
├── engines/ # Analysis engines (to be implemented)
│ └── __init__.py
└── tools/ # Analysis tools (to be implemented)
└── __init__.py
```
## Core Data Models
### DataProfile
Represents the profile of a dataset including metadata, column information, and quality metrics.
### RequirementSpec
Specification of user requirements including objectives, constraints, and expected outputs.
### AnalysisPlan
Complete analysis plan with tasks, dependencies, and tool configuration.
### AnalysisResult
Result of executing an analysis task including data, visualizations, and insights.
## Testing
All models support:
- Dictionary serialization (`to_dict()`, `from_dict()`)
- JSON serialization (`to_json()`, `from_json()`)
- Full test coverage in `tests/test_models.py`
Run tests with:
```bash
pytest tests/test_models.py -v
```

Binary file not shown.


@@ -35,6 +35,7 @@ class DataAccessLayer:
""" """
self._data = data # 私有数据AI 不可访问 self._data = data # 私有数据AI 不可访问
self._file_path = file_path self._file_path = file_path
self._output_dir = "" # 输出目录,用于图表等文件
@classmethod @classmethod
def load_from_file(cls, file_path: str, max_retries: int = 3, optimize_memory: bool = True) -> 'DataAccessLayer': def load_from_file(cls, file_path: str, max_retries: int = 3, optimize_memory: bool = True) -> 'DataAccessLayer':
@@ -170,7 +171,25 @@ class DataAccessLayer:
# 尝试转换为日期时间 # 尝试转换为日期时间
if col_data.dtype == 'object': if col_data.dtype == 'object':
try: try:
pd.to_datetime(col_data.dropna().head(100)) sample = col_data.dropna().head(20)
if len(sample) == 0:
pass
else:
# 尝试用常见日期格式解析
date_formats = ['%Y-%m-%d', '%Y/%m/%d', '%Y-%m-%d %H:%M:%S', '%Y/%m/%d %H:%M:%S', '%d/%m/%Y', '%m/%d/%Y']
parsed = False
for fmt in date_formats:
try:
pd.to_datetime(sample, format=fmt)
parsed = True
break
except (ValueError, TypeError):
continue
if not parsed:
# 最后尝试自动推断,但用 infer_datetime_format
pd.to_datetime(sample, format='mixed', dayfirst=False)
parsed = True
if parsed:
return 'datetime' return 'datetime'
except: except:
pass pass
@@ -187,10 +206,21 @@ class DataAccessLayer:
# 默认为文本类型 # 默认为文本类型
return 'text' return 'text'
def set_output_dir(self, output_dir: str):
"""
设置输出目录,图表等文件将保存到此目录下。
参数:
output_dir: 输出目录路径
"""
self._output_dir = output_dir
def execute_tool(self, tool: Any, **kwargs) -> Dict[str, Any]: def execute_tool(self, tool: Any, **kwargs) -> Dict[str, Any]:
""" """
执行工具并返回聚合结果(安全)。 执行工具并返回聚合结果(安全)。
如果设置了 output_dir图表文件会自动保存到 output_dir/charts/ 下。
参数: 参数:
tool: 分析工具实例 tool: 分析工具实例
**kwargs: 工具参数 **kwargs: 工具参数
@@ -199,6 +229,10 @@ class DataAccessLayer:
工具执行结果(聚合数据) 工具执行结果(聚合数据)
""" """
try: try:
# 如果设置了输出目录,自动修正图表输出路径
if self._output_dir:
kwargs = self._fix_output_path(tool, kwargs)
result = tool.execute(self._data, **kwargs) result = tool.execute(self._data, **kwargs)
return self._sanitize_result(result) return self._sanitize_result(result)
except Exception as e: except Exception as e:
@@ -209,6 +243,37 @@ class DataAccessLayer:
'tool': tool.name 'tool': tool.name
} }
def _fix_output_path(self, tool: Any, kwargs: dict) -> dict:
"""
确保图表输出路径指向 output_dir/charts/ 目录。
参数:
tool: 工具实例
kwargs: 工具参数
返回:
修正后的参数
"""
# 检查工具是否有 output_path 参数
props = getattr(tool, 'parameters', {}).get('properties', {})
if 'output_path' not in props:
return kwargs
charts_dir = str(Path(self._output_dir) / "charts")
if 'output_path' in kwargs:
# AI 指定了路径,但可能是相对路径如 "bar_chart.png"
output_path = kwargs['output_path']
if not Path(output_path).is_absolute() and not output_path.startswith(self._output_dir):
kwargs['output_path'] = str(Path(charts_dir) / Path(output_path).name)
else:
# AI 没指定路径,使用默认值但重定向到 charts 目录
default_path = props['output_path'].get('default', '')
if default_path:
kwargs['output_path'] = str(Path(charts_dir) / default_path)
return kwargs
def _sanitize_result(self, result: Dict[str, Any]) -> Dict[str, Any]: def _sanitize_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
""" """
确保结果不包含原始数据,只返回聚合数据。 确保结果不包含原始数据,只返回聚合数据。
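The new datetime inference above probes a list of explicit formats before letting pandas guess mixed formats. The same probing strategy can be sketched with only the standard library, using `datetime.strptime` in place of `pd.to_datetime` (names here are illustrative, not the project's API):

```python
from datetime import datetime

DATE_FORMATS = ['%Y-%m-%d', '%Y/%m/%d', '%Y-%m-%d %H:%M:%S',
                '%Y/%m/%d %H:%M:%S', '%d/%m/%Y', '%m/%d/%Y']

def looks_like_datetime(samples: list) -> bool:
    """Return True if one common date format parses every non-empty sample."""
    samples = [s for s in samples if s]
    if not samples:
        return False
    for fmt in DATE_FORMATS:
        try:
            for s in samples:
                datetime.strptime(s, fmt)
            return True  # a single format parsed every sample
        except ValueError:
            continue  # this format failed; try the next one
    return False

print(looks_like_datetime(["2026-03-09", "2026-03-10"]))  # True
print(looks_like_datetime(["abc", "def"]))                # False
```

Trying explicit formats first avoids the slow, warning-prone per-element inference that a bare `pd.to_datetime` performs on object columns.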


@@ -0,0 +1,221 @@
"""
真正的 AI 驱动数据理解引擎
AI 只能看到表头和统计摘要,通过推理理解数据
"""
import logging
from typing import Dict, Any, List
import json
from openai import OpenAI
from src.models import DataProfile, ColumnInfo
from src.config import get_config
from src.data_access import DataAccessLayer
logger = logging.getLogger(__name__)
def ai_understand_data(data_file: str) -> DataProfile:
"""
使用 AI 理解数据(只基于元数据,不看原始数据)
参数:
data_file: 数据文件路径
返回:
数据画像
"""
profile, _ = ai_understand_data_with_dal(data_file)
return profile
def ai_understand_data_with_dal(data_file: str):
"""
使用 AI 理解数据,同时返回 DataAccessLayer 以避免重复加载。
参数:
data_file: 数据文件路径
返回:
(DataProfile, DataAccessLayer) 元组
"""
# 1. 加载数据AI 不可见)
logger.info(f"加载数据: {data_file}")
dal = DataAccessLayer.load_from_file(data_file)
# 2. 生成数据画像(元数据)
logger.info("生成数据画像(元数据)")
profile = dal.get_profile()
# 3. 准备给 AI 的信息(只有元数据)
metadata = _prepare_metadata_for_ai(profile)
# 4. 调用 AI 分析
logger.info("调用 AI 分析数据特征...")
ai_analysis = _call_ai_for_analysis(metadata)
# 5. 更新数据画像
profile.inferred_type = ai_analysis.get('data_type', 'unknown')
profile.key_fields = ai_analysis.get('key_fields', {})
profile.quality_score = ai_analysis.get('quality_score', 0.0)
profile.summary = ai_analysis.get('summary', '')
return profile, dal
def _prepare_metadata_for_ai(profile: DataProfile) -> Dict[str, Any]:
"""
准备给 AI 的元数据(不包含原始数据)
参数:
profile: 数据画像
返回:
元数据字典
"""
metadata = {
"file_path": profile.file_path,
"row_count": profile.row_count,
"column_count": profile.column_count,
"columns": []
}
    # provide column-level metadata only
for col in profile.columns:
col_info = {
"name": col.name,
"dtype": col.dtype,
"missing_rate": col.missing_rate,
"unique_count": col.unique_count,
"sample_values": col.sample_values[:5] # 最多5个示例值
}
        # include statistics when available
if col.statistics:
col_info["statistics"] = col.statistics
metadata["columns"].append(col_info)
return metadata
def _call_ai_for_analysis(metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
调用 AI 分析数据特征
参数:
metadata: 数据元信息
返回:
AI 分析结果
"""
config = get_config()
    # create the OpenAI client
client = OpenAI(
api_key=config.llm.api_key,
base_url=config.llm.base_url
)
    # build the prompt
prompt = f"""你是一个数据分析专家。我会给你一个数据集的元信息(表头、统计摘要),你需要分析这个数据集。
重要:你只能看到元信息,看不到原始数据行。请基于列名、数据类型、统计特征进行推理。
数据元信息:
```json
{json.dumps(metadata, ensure_ascii=False, indent=2)}
```
请分析并回答以下问题:
1. 这是什么类型的数据?(工单数据/销售数据/用户数据/其他)
2. 哪些是关键字段?每个字段的业务含义是什么?
3. 数据质量如何0-100分
4. 用一段话总结这个数据集的特征
请以 JSON 格式返回结果:
{{
"data_type": "ticket/sales/user/other",
"key_fields": {{
"字段名1": "业务含义1",
"字段名2": "业务含义2"
}},
"quality_score": 85.5,
"summary": "数据集的总结描述"
}}
"""
try:
        # call the AI
response = client.chat.completions.create(
model=config.llm.model,
messages=[
{"role": "system", "content": "你是一个数据分析专家,擅长从元数据推断数据特征。"},
{"role": "user", "content": prompt}
],
temperature=0.3,
max_tokens=2000
)
        # parse the response
content = response.choices[0].message.content
logger.info(f"AI 响应: {content[:200]}...")
        # try to extract JSON
result = _extract_json_from_response(content)
return result
except Exception as e:
logger.error(f"AI 调用失败: {e}")
# 返回默认值
return {
"data_type": "unknown",
"key_fields": {},
"quality_score": 0.0,
"summary": f"AI 分析失败: {str(e)}"
}
def _extract_json_from_response(content: str) -> Dict[str, Any]:
"""
从 AI 响应中提取 JSON
参数:
content: AI 响应内容
返回:
解析后的 JSON 字典
"""
    # try parsing the content directly
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        pass
    # try to extract a ```json code block
    import re
    json_match = re.search(r'```json\s*(.*?)\s*```', content, re.DOTALL)
    if json_match:
        try:
            return json.loads(json_match.group(1))
        except json.JSONDecodeError:
            pass
    # try to extract the first {...} span
    json_match = re.search(r'\{.*\}', content, re.DOTALL)
    if json_match:
        try:
            return json.loads(json_match.group(0))
        except json.JSONDecodeError:
            pass

    # if everything fails, return defaults
    logger.warning("Could not extract JSON from the AI response; using defaults")
return {
"data_type": "unknown",
"key_fields": {},
"quality_score": 0.0,
"summary": content[:500]
}
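`_extract_json_from_response` above tolerates the three response shapes LLMs commonly produce: bare JSON, a fenced ```json block, and JSON embedded in prose. A condensed, testable sketch of the same fallback chain (the function name here is illustrative):

```python
import json
import re

def extract_json(content: str) -> dict:
    """Parse a dict from raw JSON, a ```json fence, or the first {...} span."""
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        pass
    # fall back to a fenced block, then to the widest brace span
    for pattern, group in ((r'```json\s*(.*?)\s*```', 1), (r'\{.*\}', 0)):
        m = re.search(pattern, content, re.DOTALL)
        if m:
            try:
                return json.loads(m.group(group))
            except json.JSONDecodeError:
                continue
    return {}

fenced = 'Here you go:\n```json\n{"data_type": "ticket"}\n```'
print(extract_json(fenced))  # {'data_type': 'ticket'}
```

Ordering matters: the fenced-block pattern must run before the greedy `\{.*\}` pattern, since the latter would also match braces inside surrounding prose.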


@@ -1,4 +1,8 @@
"""Analysis planning engine for generating dynamic analysis plans.""" """AI-driven analysis planning engine.
AI generates specific, tool-aware tasks based on actual data characteristics.
No hardcoded rules about column names or data types.
"""
import os import os
import json import json
@@ -10,68 +14,62 @@ from openai import OpenAI
from src.models.data_profile import DataProfile from src.models.data_profile import DataProfile
from src.models.requirement_spec import RequirementSpec, AnalysisObjective from src.models.requirement_spec import RequirementSpec, AnalysisObjective
from src.models.analysis_plan import AnalysisPlan, AnalysisTask from src.models.analysis_plan import AnalysisPlan, AnalysisTask
from src.tools.base import AnalysisTool
def plan_analysis( def plan_analysis(
data_profile: DataProfile, data_profile: DataProfile,
requirement: RequirementSpec requirement: RequirementSpec,
available_tools: List[AnalysisTool] = None
) -> AnalysisPlan: ) -> AnalysisPlan:
""" """
AI-driven analysis planning. AI-driven analysis planning.
Generates dynamic task list based on data features and requirements. AI sees the data profile (column names, types, stats, sample values)
and available tools, then generates a concrete task list with specific
Args: tool calls and parameters tailored to this dataset.
data_profile: Profile of the data to be analyzed
requirement: Parsed requirement specification
Returns:
AnalysisPlan with task list and configuration
Requirements: FR-3.1, FR-3.2
""" """
# Get API key from environment from src.config import get_config
api_key = os.getenv('OPENAI_API_KEY') config = get_config()
api_key = config.llm.api_key
if not api_key: if not api_key:
# Fallback to rule-based planning return _fallback_planning(data_profile, requirement)
return _fallback_analysis_planning(data_profile, requirement)
client = OpenAI(api_key=api_key) client = OpenAI(api_key=api_key, base_url=config.llm.base_url)
prompt = _build_planning_prompt(data_profile, requirement, available_tools)
# Build prompt for AI
prompt = _build_planning_prompt(data_profile, requirement)
try: try:
# Call LLM
response = client.chat.completions.create( response = client.chat.completions.create(
model="gpt-4", model=config.llm.model,
messages=[ messages=[
{"role": "system", "content": "You are a data analysis expert who creates comprehensive analysis plans based on data characteristics and user requirements."}, {"role": "system", "content": (
"You are a data analysis planning expert. "
"Given data metadata and available tools, create a concrete analysis plan. "
"Each task should specify exactly which tools to call and with what column names. "
"Respond in JSON only."
)},
{"role": "user", "content": prompt} {"role": "user", "content": prompt}
], ],
temperature=0.7, temperature=0.5,
max_tokens=3000 max_tokens=3000
) )
# Parse AI response
ai_plan = _parse_planning_response(response.choices[0].message.content) ai_plan = _parse_planning_response(response.choices[0].message.content)
# Create tasks from AI plan
tasks = [] tasks = []
for i, task_data in enumerate(ai_plan.get('tasks', [])): for i, td in enumerate(ai_plan.get('tasks', [])):
task = AnalysisTask( tasks.append(AnalysisTask(
id=task_data.get('id', f"task_{i+1}"), id=td.get('id', f"task_{i+1}"),
name=task_data.get('name', f"Task {i+1}"), name=td.get('name', f"Task {i+1}"),
description=task_data.get('description', ''), description=td.get('description', ''),
priority=task_data.get('priority', 3), priority=td.get('priority', 3),
dependencies=task_data.get('dependencies', []), dependencies=td.get('dependencies', []),
required_tools=task_data.get('required_tools', []), required_tools=td.get('required_tools', []),
expected_output=task_data.get('expected_output', ''), expected_output=td.get('expected_output', ''),
status='pending' status='pending'
) ))
tasks.append(task)
# Validate dependencies
tasks = _ensure_valid_dependencies(tasks) tasks = _ensure_valid_dependencies(tasks)
return AnalysisPlan( return AnalysisPlan(
@@ -84,69 +82,104 @@ def plan_analysis(
) )
except Exception as e: except Exception as e:
# Fallback to rule-based if AI fails return _fallback_planning(data_profile, requirement)
return _fallback_analysis_planning(data_profile, requirement)
def _build_planning_prompt( def _build_planning_prompt(
data_profile: DataProfile, data_profile: DataProfile,
requirement: RequirementSpec requirement: RequirementSpec,
available_tools: List[AnalysisTool] = None
) -> str: ) -> str:
"""Build prompt for AI planning.""" """Build prompt with full data context and tool catalog."""
column_names = [col.name for col in data_profile.columns] # Column details
column_types = {col.name: col.dtype for col in data_profile.columns} col_details = []
for col in data_profile.columns:
detail = f" - {col.name} (type: {col.dtype}, missing: {col.missing_rate:.1%}, unique: {col.unique_count})"
if col.sample_values:
samples = [str(v) for v in col.sample_values[:3]]
detail += f"\n samples: {', '.join(samples)}"
if col.statistics:
stats_str = json.dumps(col.statistics, ensure_ascii=False, default=str)[:200]
detail += f"\n stats: {stats_str}"
col_details.append(detail)
columns_section = "\n".join(col_details)
# Tool catalog
tools_section = ""
if available_tools:
tool_descs = []
for t in available_tools:
params = json.dumps(t.parameters.get('properties', {}), ensure_ascii=False)
required = t.parameters.get('required', [])
tool_descs.append(f" - {t.name}: {t.description}\n params: {params}\n required: {required}")
tools_section = "\nAvailable Tools:\n" + "\n".join(tool_descs)
# Objectives
objectives_str = "\n".join([ objectives_str = "\n".join([
f"- {obj.name}: {obj.description} (Priority: {obj.priority})" f" - {obj.name}: {obj.description} (priority: {obj.priority})"
for obj in requirement.objectives for obj in requirement.objectives
]) ])
prompt = f"""Create a comprehensive analysis plan based on the following: return f"""Create an analysis plan for this dataset.
Data Characteristics: Data Profile:
- Type: {data_profile.inferred_type} - Type: {data_profile.inferred_type}
- Rows: {data_profile.row_count} - Rows: {data_profile.row_count}, Columns: {data_profile.column_count}
- Columns: {column_names} - Quality: {data_profile.quality_score}/100
- Column Types: {column_types} - Summary: {data_profile.summary[:300]}
- Key Fields: {data_profile.key_fields}
- Quality Score: {data_profile.quality_score} Columns:
{columns_section}
Key Fields: {json.dumps(data_profile.key_fields, ensure_ascii=False)}
{tools_section}
User Requirement: {requirement.user_input}
Analysis Objectives: Analysis Objectives:
{objectives_str} {objectives_str}
Please generate an analysis plan with the following structure (return as JSON): Generate a JSON plan. Each task should reference ACTUAL column names from the data
and specify which tools to use. The AI executor will call these tools at runtime.
{{ {{
"tasks": [ "tasks": [
{{ {{
"id": "task_1", "id": "task_1",
"name": "Task name", "name": "Task name (Chinese OK)",
"description": "Detailed description", "description": "Detailed description including which columns to analyze and how. Be specific about tool parameters.",
"priority": 5, "priority": 5,
"dependencies": [], "dependencies": [],
"required_tools": ["tool1", "tool2"], "required_tools": ["tool_name1", "tool_name2"],
"expected_output": "What this task should produce" "expected_output": "What this task should produce"
}} }}
], ],
"tool_config": {{}},
"estimated_duration": 300 "estimated_duration": 300
}} }}
Guidelines: Rules:
1. Tasks should be specific and executable 1. Use ACTUAL column names from the data profile above
2. Priority: 1-5 (5 is highest) 2. Each task description should be specific enough for an AI executor to know exactly what to do
3. High-priority objectives should have high-priority tasks 3. Generate 3-8 tasks depending on data complexity
4. Include dependencies between tasks (use task IDs) 4. Higher priority objectives get higher priority tasks
5. Suggest appropriate tools for each task 5. Include distribution, groupby, statistics, trend, and visualization tasks as appropriate
6. Estimate total duration in seconds 6. Don't assume column semantics — use what the data profile tells you
7. Generate 3-8 tasks depending on complexity
""" """
return prompt
def _parse_planning_response(response_text: str) -> Dict[str, Any]: def _parse_planning_response(response_text: str) -> Dict[str, Any]:
"""Parse AI planning response into structured format.""" """Parse AI planning response."""
# Try to extract JSON from response # Try JSON code block first
json_block = re.search(r'```json\s*(.*?)\s*```', response_text, re.DOTALL)
if json_block:
try:
return json.loads(json_block.group(1))
except json.JSONDecodeError:
pass
# Try raw JSON
json_match = re.search(r'\{.*\}', response_text, re.DOTALL) json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
if json_match: if json_match:
try: try:
@@ -154,153 +187,139 @@ def _parse_planning_response(response_text: str) -> Dict[str, Any]:
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
# Fallback: return default structure return {'tasks': [], 'estimated_duration': 0}
return {
'tasks': [],
'tool_config': {},
'estimated_duration': 0
}
def _ensure_valid_dependencies(tasks: List[AnalysisTask]) -> List[AnalysisTask]: def _ensure_valid_dependencies(tasks: List[AnalysisTask]) -> List[AnalysisTask]:
"""Ensure all task dependencies are valid (no cycles, all exist).""" """Ensure all task dependencies are valid."""
task_ids = {task.id for task in tasks} task_ids = {task.id for task in tasks}
# Remove invalid dependencies
for task in tasks: for task in tasks:
task.dependencies = [dep for dep in task.dependencies if dep in task_ids and dep != task.id] task.dependencies = [d for d in task.dependencies if d in task_ids and d != task.id]
# Check for cycles and remove if found
if _has_circular_dependency(tasks): if _has_circular_dependency(tasks):
# Simple fix: remove all dependencies
for task in tasks: for task in tasks:
task.dependencies = [] task.dependencies = []
return tasks return tasks
def _fallback_analysis_planning( def _has_circular_dependency(tasks: List[AnalysisTask]) -> bool:
"""Check for circular dependencies using DFS."""
graph = {task.id: task.dependencies for task in tasks}
visited = set()
rec_stack = set()
def dfs(node):
visited.add(node)
rec_stack.add(node)
for neighbor in graph.get(node, []):
if neighbor not in visited:
if dfs(neighbor):
return True
elif neighbor in rec_stack:
return True
rec_stack.remove(node)
return False
for task_id in graph:
if task_id not in visited:
if dfs(task_id):
return True
return False
def _fallback_planning(
data_profile: DataProfile, data_profile: DataProfile,
requirement: RequirementSpec requirement: RequirementSpec
) -> AnalysisPlan: ) -> AnalysisPlan:
""" """Generic fallback planning — no hardcoded column names."""
Rule-based fallback for analysis planning.
Used when AI is unavailable or fails.
"""
tasks = [] tasks = []
task_id = 1 task_id = 1
# Generate tasks based on objectives # Task 1: Distribution analysis for categorical columns
for objective in requirement.objectives: cat_cols = [c for c in data_profile.columns if c.dtype == 'categorical']
# Basic statistics task if cat_cols:
if any(keyword in objective.name.lower() for keyword in ['统计', 'statistics', '概览', 'overview']): col_names = [c.name for c in cat_cols[:3]]
tasks.append(AnalysisTask( tasks.append(AnalysisTask(
id=f"task_{task_id}", id=f"task_{task_id}",
name=f"计算基础统计 - {objective.name}", name="分类字段分布分析",
description=f"计算与{objective.name}相关的基础统计指标", description=f"Analyze distribution of categorical columns: {', '.join(col_names)}",
priority=objective.priority, priority=4,
dependencies=[], required_tools=['get_column_distribution', 'get_value_counts'],
required_tools=['calculate_statistics'], expected_output="Distribution statistics for key categorical fields",
expected_output="统计摘要",
status='pending' status='pending'
)) ))
task_id += 1 task_id += 1
# Distribution analysis # Task 2: Numeric statistics
if any(keyword in objective.name.lower() for keyword in ['分布', 'distribution']): num_cols = [c for c in data_profile.columns if c.dtype == 'numeric']
if num_cols:
col_names = [c.name for c in num_cols[:3]]
tasks.append(AnalysisTask( tasks.append(AnalysisTask(
id=f"task_{task_id}", id=f"task_{task_id}",
name=f"分布分析 - {objective.name}", name="数值字段统计分析",
description=f"分析{objective.name}的分布特征", description=f"Calculate statistics for numeric columns: {', '.join(col_names)}",
priority=objective.priority, priority=4,
dependencies=[],
required_tools=['get_value_counts', 'create_bar_chart'],
expected_output="分布图表和统计",
status='pending'
))
task_id += 1
# Trend analysis
if any(keyword in objective.name.lower() for keyword in ['趋势', 'trend', '时间', 'time']):
tasks.append(AnalysisTask(
id=f"task_{task_id}",
name=f"趋势分析 - {objective.name}",
description=f"分析{objective.name}的时间趋势",
priority=objective.priority,
dependencies=[],
required_tools=['get_time_series', 'calculate_trend', 'create_line_chart'],
expected_output="趋势图表和分析",
status='pending'
))
task_id += 1
# Health/quality analysis
if any(keyword in objective.name.lower() for keyword in ['健康', 'health', '质量', 'quality']):
tasks.append(AnalysisTask(
id=f"task_{task_id}",
name=f"质量评估 - {objective.name}",
description=f"评估{objective.name}相关的数据质量",
priority=objective.priority,
dependencies=[],
required_tools=['calculate_statistics', 'detect_outliers'], required_tools=['calculate_statistics', 'detect_outliers'],
expected_output="质量评分和问题识别", expected_output="Descriptive statistics and outlier detection",
status='pending'
))
task_id += 1
# Task 3: Time series if datetime columns exist
dt_cols = [c for c in data_profile.columns if c.dtype == 'datetime']
if dt_cols:
tasks.append(AnalysisTask(
id=f"task_{task_id}",
name="时间趋势分析",
description=f"Analyze time trends using column: {dt_cols[0].name}",
priority=3,
required_tools=['get_time_series', 'calculate_trend'],
expected_output="Time series trends",
status='pending'
))
task_id += 1
# Task 4: Groupby analysis
if cat_cols and num_cols:
tasks.append(AnalysisTask(
id=f"task_{task_id}",
name="分组聚合分析",
description=f"Group by {cat_cols[0].name} and aggregate {num_cols[0].name}",
priority=3,
required_tools=['perform_groupby'],
expected_output="Grouped aggregation results",
status='pending' status='pending'
)) ))
task_id += 1 task_id += 1
# If no tasks generated, create default task
if not tasks: if not tasks:
tasks.append(AnalysisTask( tasks.append(AnalysisTask(
id="task_1", id="task_1",
name="综合数据分析", name="综合数据分析",
description="对数据进行全面的探索性分析", description="Perform exploratory analysis on the dataset",
priority=3, priority=3,
dependencies=[], required_tools=['get_column_distribution', 'calculate_statistics'],
required_tools=['calculate_statistics', 'get_value_counts'], expected_output="Basic data analysis",
expected_output="数据分析报告",
status='pending' status='pending'
)) ))
return AnalysisPlan( return AnalysisPlan(
objectives=requirement.objectives, objectives=requirement.objectives,
tasks=tasks, tasks=tasks,
tool_config={}, estimated_duration=len(tasks) * 60,
estimated_duration=len(tasks) * 60, # 60 seconds per task
created_at=datetime.now(), created_at=datetime.now(),
updated_at=datetime.now() updated_at=datetime.now()
) )
def validate_task_dependencies(tasks: List[AnalysisTask]) -> Dict[str, Any]: def validate_task_dependencies(tasks: List[AnalysisTask]) -> Dict[str, Any]:
""" """Validate task dependencies."""
Validate task dependencies.
Checks:
1. All dependencies exist
2. No circular dependencies (forms DAG)
Args:
tasks: List of analysis tasks
Returns:
Dictionary with validation results
Requirements: FR-3.1
"""
task_ids = {task.id for task in tasks} task_ids = {task.id for task in tasks}
# Check if all dependencies exist
missing_deps = [] missing_deps = []
for task in tasks: for task in tasks:
for dep_id in task.dependencies: for dep_id in task.dependencies:
if dep_id not in task_ids: if dep_id not in task_ids:
missing_deps.append({ missing_deps.append({'task_id': task.id, 'missing_dep': dep_id})
'task_id': task.id,
'missing_dep': dep_id
})
# Check for circular dependencies
has_cycle = _has_circular_dependency(tasks) has_cycle = _has_circular_dependency(tasks)
return { return {
@@ -309,36 +328,3 @@ def validate_task_dependencies(tasks: List[AnalysisTask]) -> Dict[str, Any]:
'has_circular_dependency': has_cycle, 'has_circular_dependency': has_cycle,
'forms_dag': not has_cycle 'forms_dag': not has_cycle
} }
def _has_circular_dependency(tasks: List[AnalysisTask]) -> bool:
"""Check if task dependencies form a cycle using DFS."""
# Build adjacency list
graph = {task.id: task.dependencies for task in tasks}
# Track visited nodes
visited = set()
rec_stack = set()
def has_cycle_util(node: str) -> bool:
visited.add(node)
rec_stack.add(node)
# Check all neighbors
for neighbor in graph.get(node, []):
if neighbor not in visited:
if has_cycle_util(neighbor):
return True
elif neighbor in rec_stack:
return True
rec_stack.remove(node)
return False
# Check each node
for task_id in graph:
if task_id not in visited:
if has_cycle_util(task_id):
return True
return False
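The relocated `_has_circular_dependency` is a standard DFS cycle check with a recursion stack. A minimal standalone version of the same check, using a plain task-id to dependencies dict instead of `AnalysisTask` objects:

```python
def has_cycle(graph: dict) -> bool:
    """Detect a cycle in a task-id -> dependency-list mapping via DFS."""
    visited, rec_stack = set(), set()

    def dfs(node):
        visited.add(node)
        rec_stack.add(node)
        for neighbor in graph.get(node, []):
            if neighbor not in visited:
                if dfs(neighbor):
                    return True
            elif neighbor in rec_stack:
                return True  # back edge into the current path: cycle found
        rec_stack.remove(node)
        return False

    return any(dfs(n) for n in graph if n not in visited)

print(has_cycle({"task_1": ["task_2"], "task_2": ["task_1"]}))  # True
print(has_cycle({"task_1": [], "task_2": ["task_1"]}))          # False
```

The `rec_stack` set is what distinguishes a genuine cycle (back edge into the current DFS path) from a harmless diamond dependency reached twice via different paths.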


@@ -9,6 +9,7 @@ from openai import OpenAI
 from src.models.analysis_plan import AnalysisPlan, AnalysisTask
 from src.models.analysis_result import AnalysisResult
+from src.config import get_config


 def adjust_plan(
@@ -30,13 +31,14 @@ def adjust_plan(
     Requirements: FR-3.3, FR-5.4
     """
-    # Get API key
-    api_key = os.getenv('OPENAI_API_KEY')
+    # Get config
+    config = get_config()
+    api_key = config.llm.api_key
     if not api_key:
         # Fallback to rule-based adjustment
         return _fallback_plan_adjustment(plan, completed_results)

-    client = OpenAI(api_key=api_key)
+    client = OpenAI(api_key=api_key, base_url=config.llm.base_url)

     # Build prompt for AI
     prompt = _build_adjustment_prompt(plan, completed_results)
@@ -44,7 +46,7 @@ def adjust_plan(
     try:
         # Call LLM
         response = client.chat.completions.create(
-            model="gpt-4",
+            model=config.llm.model,
             messages=[
                 {"role": "system", "content": "You are a data analysis expert who adjusts analysis plans based on findings."},
                 {"role": "user", "content": prompt}


@@ -4,8 +4,11 @@
""" """
import os import os
import re
import json
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from datetime import datetime from datetime import datetime
from pathlib import PurePosixPath, Path
from src.models.analysis_result import AnalysisResult from src.models.analysis_result import AnalysisResult
from src.models.requirement_spec import RequirementSpec from src.models.requirement_spec import RequirementSpec
@@ -309,6 +312,56 @@ def _generate_conclusion_summary(key_findings: List[Dict[str, Any]]) -> str:
def _to_relative_chart_path(chart_path: str, report_dir: str = "") -> str:
"""
将图表绝对路径转换为相对于报告文件的路径。
例如:
chart_path = "analysis_output/run_xxx/charts/bar.png"
report_dir = "analysis_output/run_xxx"
"./charts/bar.png"
如果无法计算相对路径,则只保留 ./charts/filename.png
"""
if not chart_path:
return chart_path
# 统一为正斜杠
chart_path = chart_path.replace('\\', '/')
if report_dir:
report_dir = report_dir.replace('\\', '/')
try:
rel = os.path.relpath(chart_path, report_dir).replace('\\', '/')
return './' + rel if not rel.startswith('.') else rel
except ValueError:
pass
# Fallback: 提取 charts/filename 部分
parts = chart_path.replace('\\', '/').split('/')
if 'charts' in parts:
idx = parts.index('charts')
return './' + '/'.join(parts[idx:])
# 最后兜底:直接用文件名
return './charts/' + os.path.basename(chart_path)
def _convert_chart_paths_in_report(report: str, report_dir: str = "") -> str:
"""
将报告中所有 ![xxx](绝对路径) 的图表路径转换为相对路径。
同时统一反斜杠为正斜杠。
"""
def replace_img(match):
alt = match.group(1)
path = match.group(2)
rel_path = _to_relative_chart_path(path, report_dir)
return f'![{alt}]({rel_path})'
# 匹配 ![任意文字](路径)
return re.sub(r'!\[([^\]]*)\]\(([^)]+)\)', replace_img, report)
 def generate_report(
     results: List[AnalysisResult],
     requirement: RequirementSpec,
@@ -339,14 +392,19 @@ def generate_report(
     structure = organize_report_structure(key_findings, requirement, data_profile)

     # 尝试使用AI生成报告
-    api_key = os.getenv('OPENAI_API_KEY')
+    from src.config import get_config
+    config = get_config()
+    api_key = config.llm.api_key
     if api_key:
         try:
             from openai import OpenAI
-            client = OpenAI(api_key=api_key)
+            client = OpenAI(
+                api_key=api_key,
+                base_url=config.llm.base_url
+            )
             report = _generate_report_with_ai(
-                client, results, key_findings, structure, requirement, data_profile
+                client, config, results, key_findings, structure, requirement, data_profile
             )
         except Exception as e:
             # Fallback to rule-based generation
@@ -359,16 +417,21 @@ def generate_report(
         results, key_findings, structure, requirement, data_profile
     )

-    # 保存报告
-    if output_path:
+    # 保存报告(仅当 output_path 指向文件时)
+    if output_path and not os.path.isdir(output_path):
         with open(output_path, 'w', encoding='utf-8') as f:
             f.write(report)

+    # 将图表路径转换为相对于报告所在目录的路径
+    report_dir = output_path if output_path and os.path.isdir(output_path) else ""
+    report = _convert_chart_paths_in_report(report, report_dir)
+
     return report
 def _generate_report_with_ai(
     client,
+    config,
     results: List[AnalysisResult],
     key_findings: List[Dict[str, Any]],
     structure: Dict[str, Any],
@@ -377,6 +440,29 @@ def _generate_report_with_ai(
 ) -> str:
     """使用AI生成报告。"""

+    # 构建分析数据摘要从results中提取实际数据
+    data_summaries = []
+    all_chart_paths = []
+    for r in results:
+        if r.success and r.data:
+            data_str = json.dumps(r.data, ensure_ascii=False, default=str)[:1500]
+            data_summaries.append(f"### {r.task_name}\n{data_str}")
+            # 收集所有图表路径
+            for viz in (r.visualizations or []):
+                if viz:
+                    all_chart_paths.append(_to_relative_chart_path(viz))
+            if isinstance(r.data, dict):
+                for key, val in r.data.items():
+                    if isinstance(val, dict) and val.get('chart_path'):
+                        all_chart_paths.append(_to_relative_chart_path(val['chart_path']))
+
+    data_section = "\n\n".join(data_summaries) if data_summaries else "无详细数据"
+
+    # 图表路径列表
+    charts_section = ""
+    if all_chart_paths:
+        charts_section = "\n可用图表文件(请在报告中嵌入):\n" + "\n".join(f"- {p}" for p in all_chart_paths)
+
     # 构建提示
     prompt = f"""你是一位专业的数据分析师,需要根据分析结果生成一份完整的分析报告。
@@ -386,40 +472,44 @@ def _generate_report_with_ai(
 - 列数:{data_profile.column_count}
 - 质量分数:{data_profile.quality_score}/100

+关键字段:
+{chr(10).join(f"- {k}: {v}" for k, v in data_profile.key_fields.items())}
+
 用户需求:
 {requirement.user_input}

 分析目标:
 {chr(10).join(f"- {obj.name}: {obj.description}" for obj in requirement.objectives)}

+分析结果数据:
+{data_section}
+
 关键发现(按重要性排序):
 {chr(10).join(f"{i+1}. [{f['category']}] {f['finding']}" for i, f in enumerate(key_findings[:10]))}

 已完成的分析任务:
-{chr(10).join(f"- {r.task_name}: {'成功' if r.success else '失败'}" for r in results)}
+{chr(10).join(f"- {r.task_name}: {'成功' if r.success else '失败'}, 洞察: {'; '.join(r.insights[:3])}" for r in results)}
+{charts_section}

-跳过的分析:
-{chr(10).join(f"- {r.task_name}: {r.error}" for r in results if not r.success)}
-
-请生成一份专业的分析报告,包含以下部分:
+请生成一份专业的Markdown分析报告包含

-1. 执行摘要3-5个关键发现
-2. 数据概览
-3. 详细分析(按主题组织)
-4. 结论与建议
+1. **执行摘要**3-5个关键发现用数据说话
+2. **数据概览**(数据集基本信息)
+3. **详细分析**(按主题组织,引用具体数据和数字)
+4. **结论与建议**(可操作的建议,说明依据)

 要求:
 - 使用Markdown格式
-- 突出异常和趋势
+- 突出异常和趋势,引用具体数字
 - 提供可操作的建议
-- 说明建议的依据
-- 如果有分析被跳过,说明原因
 - 使用清晰的结构和标题
+- 用中文撰写
+- 重要:在报告中嵌入图表,使用 ![描述](图表路径) 格式。将图表放在相关分析段落旁边,让报告图文结合。每个图表都要嵌入,不要遗漏。
 """

     try:
         response = client.chat.completions.create(
-            model="gpt-4",
+            model=config.llm.model,
             messages=[
                 {"role": "system", "content": "你是一位专业的数据分析师,擅长从数据中提炼洞察并撰写清晰的分析报告。"},
                 {"role": "user", "content": prompt}
View File

@@ -6,6 +6,7 @@ from openai import OpenAI
 from src.models.requirement_spec import RequirementSpec, AnalysisObjective
 from src.models.data_profile import DataProfile
+from src.config import get_config

 def understand_requirement(
@@ -29,13 +30,14 @@ def understand_requirement(
     Requirements: FR-2.1, FR-2.2
     """
-    # Get API key from environment
-    api_key = os.getenv('OPENAI_API_KEY')
+    # Get config
+    config = get_config()
+    api_key = config.llm.api_key
     if not api_key:
         # Fallback to rule-based analysis if no API key
         return _fallback_requirement_understanding(user_input, data_profile, template_path)

-    client = OpenAI(api_key=api_key)
+    client = OpenAI(api_key=api_key, base_url=config.llm.base_url)

     # Build prompt for AI
     prompt = _build_requirement_prompt(user_input, data_profile, template_path)
@@ -43,7 +45,7 @@ def understand_requirement(
     try:
         # Call LLM
         response = client.chat.completions.create(
-            model="gpt-4",
+            model=config.llm.model,
             messages=[
                 {"role": "system", "content": "You are a data analysis expert who understands user requirements and converts them into concrete analysis objectives."},
                 {"role": "user", "content": prompt}

View File

@@ -1,9 +1,9 @@
-"""Task execution engine using ReAct pattern."""
+"""Task execution engine using ReAct pattern — fully AI-driven."""

-import os
 import json
 import re
 import time
+import logging
 from typing import List, Dict, Any, Optional
 from openai import OpenAI
@@ -11,6 +11,9 @@ from src.models.analysis_plan import AnalysisTask
 from src.models.analysis_result import AnalysisResult
 from src.tools.base import AnalysisTool
 from src.data_access import DataAccessLayer
+from src.config import get_config
+
+logger = logging.getLogger(__name__)

 def execute_task(
@@ -21,57 +24,42 @@ def execute_task(
 ) -> AnalysisResult:
     """
     Execute analysis task using ReAct pattern.
-
-    ReAct loop: Thought -> Action -> Observation -> repeat
-
-    Args:
-        task: Analysis task to execute
-        tools: Available analysis tools
-        data_access: Data access layer for executing tools
-        max_iterations: Maximum number of iterations
-
-    Returns:
-        AnalysisResult with execution results
-
-    Requirements: FR-5.1
+    AI decides which tools to call and with what parameters.
+    No hardcoded heuristics — everything is AI-driven.
     """
     start_time = time.time()
+    config = get_config()
+    api_key = config.llm.api_key

-    # Get API key
-    api_key = os.getenv('OPENAI_API_KEY')
     if not api_key:
-        # Fallback to simple execution
         return _fallback_task_execution(task, tools, data_access)

-    client = OpenAI(api_key=api_key)
+    client = OpenAI(api_key=api_key, base_url=config.llm.base_url)

-    # Execution history
     history = []
     visualizations = []
+    column_names = data_access.columns

     try:
         for iteration in range(max_iterations):
-            # Thought: AI decides next action
-            thought_prompt = _build_thought_prompt(task, tools, history)
+            prompt = _build_thought_prompt(task, tools, history, column_names)

-            thought_response = client.chat.completions.create(
-                model="gpt-4",
+            response = client.chat.completions.create(
+                model=config.llm.model,
                 messages=[
-                    {"role": "system", "content": "You are a data analyst executing analysis tasks. Use the ReAct pattern: think, act, observe."},
-                    {"role": "user", "content": thought_prompt}
+                    {"role": "system", "content": _system_prompt()},
+                    {"role": "user", "content": prompt}
                 ],
-                temperature=0.7,
-                max_tokens=1000
+                temperature=0.3,
+                max_tokens=1200
             )

-            thought = _parse_thought_response(thought_response.choices[0].message.content)
+            thought = _parse_thought_response(response.choices[0].message.content)
             history.append({"type": "thought", "content": thought})

-            # Check if task is complete
             if thought.get('is_completed', False):
                 break

-            # Action: Execute selected tool
             tool_name = thought.get('selected_tool')
             tool_params = thought.get('tool_params', {})
@@ -84,95 +72,126 @@ def execute_task(
                     "tool": tool_name,
                     "params": tool_params
                 })

-                # Observation: Record result
                 history.append({
                     "type": "observation",
                     "result": action_result
                 })

-                # Track visualizations
-                if 'visualization_path' in action_result:
+                if isinstance(action_result, dict) and 'visualization_path' in action_result:
                     visualizations.append(action_result['visualization_path'])
+                if isinstance(action_result, dict) and action_result.get('data', {}).get('chart_path'):
+                    visualizations.append(action_result['data']['chart_path'])
+            else:
+                history.append({
+                    "type": "observation",
+                    "result": {"error": f"Tool '{tool_name}' not found. Available: {[t.name for t in tools]}"}
+                })

-        # Extract insights from history
         insights = extract_insights(history, client)
         execution_time = time.time() - start_time

+        # Collect all observation data
+        all_data = {}
+        for entry in history:
+            if entry['type'] == 'observation':
+                result = entry.get('result', {})
+                if isinstance(result, dict) and result.get('success', True):
+                    all_data[f"step_{len(all_data)}"] = result
+
         return AnalysisResult(
             task_id=task.id,
             task_name=task.name,
             success=True,
-            data=history[-1].get('result', {}) if history else {},
+            data=all_data,
             visualizations=visualizations,
             insights=insights,
             execution_time=execution_time
         )

     except Exception as e:
-        execution_time = time.time() - start_time
+        logger.error(f"Task execution failed: {e}")
         return AnalysisResult(
             task_id=task.id,
             task_name=task.name,
             success=False,
             error=str(e),
-            execution_time=execution_time
+            execution_time=time.time() - start_time
         )
def _system_prompt() -> str:
return (
"You are a data analyst executing analysis tasks by calling tools. "
"You can ONLY see column names and tool descriptions — never raw data rows. "
"You MUST call tools to get any data. Always respond with valid JSON. "
"Use actual column names. Pick the right tool and parameters for the task."
)
 def _build_thought_prompt(
     task: AnalysisTask,
     tools: List[AnalysisTool],
-    history: List[Dict[str, Any]]
+    history: List[Dict[str, Any]],
+    column_names: List[str] = None
 ) -> str:
-    """Build prompt for thought step."""
+    """Build prompt for the ReAct thought step."""
     tool_descriptions = "\n".join([
-        f"- {tool.name}: {tool.description}"
+        f"- {tool.name}: {tool.description}\n  Parameters: {json.dumps(tool.parameters.get('properties', {}), ensure_ascii=False)}"
         for tool in tools
     ])

-    history_str = "\n".join([
-        f"{i+1}. {h['type']}: {str(h.get('content', h.get('result', '')))[:200]}"
-        for i, h in enumerate(history[-5:])  # Last 5 steps
-    ])
+    columns_str = f"\nAvailable Data Columns: {', '.join(column_names)}\n" if column_names else ""

-    prompt = f"""Task: {task.description}
+    history_str = ""
+    if history:
+        for h in history[-8:]:
+            if h['type'] == 'thought':
+                content = h.get('content', {})
+                history_str += f"\nThought: {content.get('reasoning', '')[:200]}"
+            elif h['type'] == 'action':
+                history_str += f"\nAction: {h.get('tool', '')}({json.dumps(h.get('params', {}), ensure_ascii=False)})"
+            elif h['type'] == 'observation':
+                result = h.get('result', {})
+                result_str = json.dumps(result, ensure_ascii=False, default=str)[:500]
+                history_str += f"\nObservation: {result_str}"
+
+    actions_taken = sum(1 for h in history if h['type'] == 'action')
+
+    return f"""Task: {task.description}

 Expected Output: {task.expected_output}
+{columns_str}
 Available Tools:
 {tool_descriptions}

-Execution History:
-{history_str if history else "No history yet"}
+Execution History:{history_str if history_str else " (none yet — start by calling a tool)"}

-Think about:
-1. What is the current state?
-2. What should I do next?
-3. Which tool should I use?
-4. Is the task completed?
+Actions taken: {actions_taken}

-Respond in JSON format:
+Instructions:
+1. Pick the most relevant tool and call it with correct column names.
+2. After each observation, decide if you need more data or can conclude.
+3. Aim for 2-4 tool calls total to gather enough data.
+4. IMPORTANT: For key findings, also generate visualizations (charts) using create_bar_chart, create_pie_chart, create_line_chart, or create_heatmap. The report needs charts embedded — text-only results are not enough.
+5. When you have enough data AND have generated at least one chart, set is_completed=true and summarize findings in reasoning.
+
+Respond ONLY with this JSON (no other text):
 {{
-    "reasoning": "Your reasoning",
+    "reasoning": "your analysis reasoning",
     "is_completed": false,
     "selected_tool": "tool_name",
     "tool_params": {{"param": "value"}}
 }}
 """
-    return prompt
 def _parse_thought_response(response_text: str) -> Dict[str, Any]:
-    """Parse thought response from AI."""
+    """Parse AI thought response JSON."""
     json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
     if json_match:
         try:
             return json.loads(json_match.group())
         except json.JSONDecodeError:
             pass
     return {
         'reasoning': response_text,
         'is_completed': False,
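The brace-matching fallback parser in the hunk above can be exercised in isolation. A minimal standalone sketch of the same strategy (the sample reply text is hypothetical):

```python
import json
import re

def parse_thought(response_text: str) -> dict:
    """Extract the first {...} span from an LLM reply; fall back to raw text."""
    json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
    if json_match:
        try:
            return json.loads(json_match.group())
        except json.JSONDecodeError:
            pass
    # No parseable JSON: keep the raw text as reasoning so the loop can continue
    return {'reasoning': response_text, 'is_completed': False}

reply = ('Sure, here is my plan:\n'
         '{"reasoning": "count rows", "is_completed": false, '
         '"selected_tool": "get_value_counts"}')
print(parse_thought(reply)['selected_tool'])
# → get_value_counts
```

Note the greedy `\{.*\}` with `re.DOTALL` spans from the first `{` to the last `}`, which is fine for a single JSON object but would mis-parse a reply containing two separate objects.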
@@ -186,80 +205,78 @@ def call_tool(
     data_access: DataAccessLayer,
     **kwargs
 ) -> Dict[str, Any]:
-    """
-    Call analysis tool and return result.
-
-    Args:
-        tool: Tool to execute
-        data_access: Data access layer
-        **kwargs: Tool parameters
-
-    Returns:
-        Tool execution result
-
-    Requirements: FR-5.2
-    """
+    """Call an analysis tool and return the result."""
     try:
         result = data_access.execute_tool(tool, **kwargs)
-        return {
-            'success': True,
-            'data': result
-        }
+        return {'success': True, 'data': result}
     except Exception as e:
-        return {
-            'success': False,
-            'error': str(e)
-        }
+        return {'success': False, 'error': str(e)}
 def extract_insights(
     history: List[Dict[str, Any]],
     client: Optional[OpenAI] = None
 ) -> List[str]:
-    """
-    Extract insights from execution history.
-
-    Args:
-        history: Execution history
-        client: OpenAI client (optional)
-
-    Returns:
-        List of insights
-
-    Requirements: FR-5.4
-    """
+    """Extract insights from execution history using AI."""
     if not client:
-        # Simple extraction without AI
-        insights = []
-        for entry in history:
-            if entry['type'] == 'observation':
-                result = entry.get('result', {})
-                if isinstance(result, dict) and 'data' in result:
-                    insights.append(f"Found data: {str(result['data'])[:100]}")
-        return insights[:5]  # Limit to 5
+        return _extract_insights_from_observations(history)

-    # AI-driven insight extraction
-    history_str = json.dumps(history, indent=2, ensure_ascii=False)[:3000]
+    config = get_config()
+    history_str = json.dumps(history, indent=2, ensure_ascii=False, default=str)[:4000]

     try:
         response = client.chat.completions.create(
-            model="gpt-4",
+            model=config.llm.model,
             messages=[
-                {"role": "system", "content": "Extract key insights from analysis execution history."},
-                {"role": "user", "content": f"Execution history:\n{history_str}\n\nExtract 3-5 key insights as a JSON array of strings."}
+                {"role": "system", "content": "You are a data analyst. Extract key insights from analysis results. Respond in Chinese. Return a JSON array of 3-5 insight strings with specific numbers."},
+                {"role": "user", "content": f"Execution history:\n{history_str}\n\nExtract 3-5 key data-driven insights as a JSON array of strings."}
             ],
-            temperature=0.7,
-            max_tokens=500
+            temperature=0.5,
+            max_tokens=800
         )
+        text = response.choices[0].message.content

-        insights_text = response.choices[0].message.content
-        json_match = re.search(r'\[.*\]', insights_text, re.DOTALL)
+        json_match = re.search(r'\[.*\]', text, re.DOTALL)
         if json_match:
-            return json.loads(json_match.group())
-    except:
-        pass
+            parsed = json.loads(json_match.group())
+            if isinstance(parsed, list) and len(parsed) > 0:
+                return parsed
+    except Exception as e:
+        logger.warning(f"AI insight extraction failed: {e}")

-    return ["Analysis completed successfully"]
+    return _extract_insights_from_observations(history)
def _extract_insights_from_observations(history: List[Dict[str, Any]]) -> List[str]:
"""Fallback: extract insights directly from observation data."""
insights = []
for entry in history:
if entry['type'] != 'observation':
continue
result = entry.get('result', {})
if not isinstance(result, dict):
continue
data = result.get('data', result)
if not isinstance(data, dict):
continue
if 'groups' in data:
top = data['groups'][:3] if isinstance(data['groups'], list) else []
if top:
group_str = ', '.join(f"{g.get('group','?')}: {g.get('value',0)}" for g in top)
insights.append(f"Top groups: {group_str}")
if 'distribution' in data:
dist = data['distribution'][:3] if isinstance(data['distribution'], list) else []
if dist:
dist_str = ', '.join(f"{d.get('value','?')}: {d.get('percentage',0):.1f}%" for d in dist)
insights.append(f"Distribution: {dist_str}")
if 'trend' in data:
insights.append(f"Trend: {data['trend']}, growth rate: {data.get('growth_rate', 'N/A')}")
if 'outlier_count' in data:
insights.append(f"Outliers: {data['outlier_count']} ({data.get('outlier_percentage', 0):.1f}%)")
if 'mean' in data and 'column' in data:
insights.append(f"{data['column']}: mean={data['mean']:.2f}, median={data.get('median', 'N/A')}")
return insights[:5] if insights else ["Analysis completed"]
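The rule-based fallback above pattern-matches known keys in tool observations. A trimmed, self-contained sketch of that idea, covering just two of the patterns (the sample history is hypothetical):

```python
def insights_from_observations(history: list) -> list:
    """Rule-based fallback: pull headline numbers out of tool observations."""
    insights = []
    for entry in history:
        if entry.get('type') != 'observation':
            continue
        data = entry.get('result', {}).get('data', {})
        if not isinstance(data, dict):
            continue
        if 'distribution' in data:
            top = data['distribution'][:3]
            insights.append('Distribution: ' + ', '.join(
                f"{d['value']}: {d['percentage']:.1f}%" for d in top))
        if 'mean' in data and 'column' in data:
            insights.append(f"{data['column']}: mean={data['mean']:.2f}")
    return insights[:5] or ['Analysis completed']

history = [
    {'type': 'thought', 'content': {}},
    {'type': 'observation',
     'result': {'data': {'column': 'speed', 'mean': 63.2, 'median': 61.0}}},
]
print(insights_from_observations(history))
# → ['speed: mean=63.20']
```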
 def _find_tool(tools: List[AnalysisTool], tool_name: str) -> Optional[AnalysisTool]:
@@ -275,42 +292,53 @@ def _fallback_task_execution(
     tools: List[AnalysisTool],
     data_access: DataAccessLayer
 ) -> AnalysisResult:
-    """Simple fallback execution without AI."""
+    """Fallback execution without AI — runs required tools with minimal params."""
     start_time = time.time()
+    all_data = {}
+    insights = []

     try:
-        # Execute first applicable tool
-        for tool_name in task.required_tools:
+        columns = data_access.columns
+        tools_to_run = task.required_tools if task.required_tools else [t.name for t in tools[:3]]
+
+        for tool_name in tools_to_run:
             tool = _find_tool(tools, tool_name)
-            if tool:
-                result = call_tool(tool, data_access)
-                execution_time = time.time() - start_time
+            if not tool:
+                continue
+            # Try calling with first column as a basic param
+            params = _guess_minimal_params(tool, columns)
+            if params:
+                result = call_tool(tool, data_access, **params)
+                if result.get('success'):
+                    all_data[tool_name] = result.get('data', {})

-                return AnalysisResult(
-                    task_id=task.id,
-                    task_name=task.name,
-                    success=result.get('success', False),
-                    data=result.get('data', {}),
-                    insights=[f"Executed {tool_name}"],
-                    execution_time=execution_time
-                )
-
-        # No tools executed
-        execution_time = time.time() - start_time
         return AnalysisResult(
             task_id=task.id,
             task_name=task.name,
-            success=False,
-            error="No applicable tools found",
-            execution_time=execution_time
+            success=True,
+            data=all_data,
+            insights=insights or ["Fallback execution completed"],
+            execution_time=time.time() - start_time
         )

     except Exception as e:
-        execution_time = time.time() - start_time
         return AnalysisResult(
             task_id=task.id,
             task_name=task.name,
             success=False,
             error=str(e),
-            execution_time=execution_time
+            execution_time=time.time() - start_time
         )
def _guess_minimal_params(tool: AnalysisTool, columns: List[str]) -> Optional[Dict[str, Any]]:
"""Guess minimal params for fallback — just pick first applicable column."""
props = tool.parameters.get('properties', {})
required = tool.parameters.get('required', [])
params = {}
for param_name in required:
prop = props.get(param_name, {})
if prop.get('type') == 'string' and 'column' in param_name.lower():
params[param_name] = columns[0] if columns else ''
elif prop.get('type') == 'string':
params[param_name] = columns[0] if columns else ''
return params if params else None
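The parameter-guessing helper above walks a JSON-Schema-style `parameters` dict. A standalone sketch of the same walk, simplified to only the required-string case (the `schema` and column names are made up for illustration):

```python
def guess_minimal_params(schema: dict, columns: list):
    """Fill each required string parameter with the first column name."""
    props = schema.get('properties', {})
    params = {}
    for name in schema.get('required', []):
        if props.get(name, {}).get('type') == 'string':
            params[name] = columns[0] if columns else ''
    return params or None  # None signals "cannot call this tool blindly"

schema = {
    'type': 'object',
    'properties': {
        'column': {'type': 'string', 'description': '要分析的列名'},
        'top_n': {'type': 'integer', 'default': 10},
    },
    'required': ['column'],
}
print(guess_minimal_params(schema, ['vin', 'speed']))
# → {'column': 'vin'}
```

Optional parameters like `top_n` are left to the tool's defaults; only required strings are filled.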

View File

@@ -10,15 +10,15 @@ from src.env_loader import load_env_with_fallback
 from src.data_access import DataAccessLayer
 from src.models import DataProfile, RequirementSpec, AnalysisPlan, AnalysisResult
 from src.engines import (
-    understand_data,
     understand_requirement,
     plan_analysis,
     execute_task,
     adjust_plan,
     generate_report
 )
+from src.engines.ai_data_understanding import ai_understand_data_with_dal
 from src.tools.tool_manager import ToolManager
-from src.tools.base import ToolRegistry
+from src.tools.base import _global_registry
 from src.error_handling import execute_task_with_recovery
 from src.logging_config import (
     log_stage_start,
@@ -81,7 +81,7 @@ class AnalysisOrchestrator:
         # 初始化组件
         self.data_access: Optional[DataAccessLayer] = None
-        self.tool_manager = ToolManager(ToolRegistry())
+        self.tool_manager = ToolManager()

         # 阶段结果
         self.data_profile: Optional[DataProfile] = None
@@ -211,7 +211,7 @@ class AnalysisOrchestrator:
     def _stage_data_understanding(self) -> DataProfile:
         """
-        阶段1数据理解
+        阶段1数据理解AI驱动

         返回:
             数据画像
@@ -219,15 +219,14 @@ class AnalysisOrchestrator:
         log_stage_start(logger, "数据理解")
         stage_start = time.time()

-        # 加载数据
+        # 使用 AI 驱动的数据理解,同时获取 DAL 避免重复加载
         logger.info(f"加载数据文件: {self.data_file}")
-        self.data_access = DataAccessLayer.load_from_file(self.data_file)
-        logger.info(f"✓ 数据加载成功: {self.data_access.shape[0]} 行, {self.data_access.shape[1]} 列")
+        data_profile, self.data_access = ai_understand_data_with_dal(self.data_file)

-        # 理解数据
-        logger.info("分析数据特征...")
-        data_profile = understand_data(self.data_access)
+        # 设置输出目录,确保图表等文件保存到正确位置
+        self.data_access.set_output_dir(str(self.output_dir))

+        logger.info(f"✓ 数据加载成功: {data_profile.row_count} 行, {data_profile.column_count} 列")
         logger.info(f"✓ 数据类型: {data_profile.inferred_type}")
         logger.info(f"✓ 数据质量分数: {data_profile.quality_score:.1f}/100")
         logger.info(f"✓ 关键字段: {list(data_profile.key_fields.keys())}")
@@ -271,11 +270,15 @@ class AnalysisOrchestrator:
         """
         log_stage_start(logger, "分析规划")

-        # 生成分析计划
+        # 选择工具(提前选好,传给 planner
+        tools = self.tool_manager.select_tools(self.data_profile)
+
+        # 生成分析计划(传入可用工具,让 AI 生成 tool-aware 的任务)
         logger.info("生成分析计划...")
         analysis_plan = plan_analysis(
             data_profile=self.data_profile,
-            requirement=self.requirement_spec
+            requirement=self.requirement_spec,
+            available_tools=tools
         )

         logger.info(f"✓ 生成任务数量: {len(analysis_plan.tasks)}")

View File

@@ -1,301 +0,0 @@
"""数据查询工具。"""
import pandas as pd
import numpy as np
from typing import Dict, Any
from src.tools.base import AnalysisTool
from src.models import DataProfile
class GetColumnDistributionTool(AnalysisTool):
"""获取列的分布统计工具。"""
@property
def name(self) -> str:
return "get_column_distribution"
@property
def description(self) -> str:
return "获取指定列的分布统计信息,包括值计数、百分比等。适用于分类和数值列。"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"column": {
"type": "string",
"description": "要分析的列名"
},
"top_n": {
"type": "integer",
"description": "返回前N个最常见的值",
"default": 10
}
},
"required": ["column"]
}
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
"""执行列分布分析。"""
column = kwargs.get('column')
top_n = kwargs.get('top_n', 10)
if column not in data.columns:
return {'error': f'{column} 不存在'}
col_data = data[column]
value_counts = col_data.value_counts().head(top_n)
total = len(col_data.dropna())
distribution = []
for value, count in value_counts.items():
distribution.append({
'value': str(value),
'count': int(count),
'percentage': float(count / total * 100) if total > 0 else 0.0
})
return {
'column': column,
'total_count': int(total),
'unique_count': int(col_data.nunique()),
'missing_count': int(col_data.isna().sum()),
'distribution': distribution
}
def is_applicable(self, data_profile: DataProfile) -> bool:
"""适用于所有数据。"""
return True
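The distribution computation above reduces to `value_counts` over non-null rows. A small self-contained run of that core step, assuming pandas is available (the alarm data is made up):

```python
import pandas as pd

df = pd.DataFrame({'alarm_type': ['overspeed', 'overspeed', 'fatigue',
                                  'overspeed', 'fatigue', None]})
col = df['alarm_type']
total = len(col.dropna())  # percentages are over non-null values only

distribution = [
    {'value': str(v), 'count': int(c), 'percentage': round(c / total * 100, 1)}
    for v, c in col.value_counts().head(10).items()
]
print(distribution)
# → [{'value': 'overspeed', 'count': 3, 'percentage': 60.0},
#    {'value': 'fatigue', 'count': 2, 'percentage': 40.0}]
```

Note `value_counts()` drops NaN by default, which is why `total` must also exclude nulls for the percentages to sum to 100.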
class GetValueCountsTool(AnalysisTool):
"""获取值计数工具。"""
@property
def name(self) -> str:
return "get_value_counts"
@property
def description(self) -> str:
return "获取指定列的值计数,返回每个唯一值的出现次数。"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"column": {
"type": "string",
"description": "要分析的列名"
},
"normalize": {
"type": "boolean",
"description": "是否返回百分比而不是计数",
"default": False
}
},
"required": ["column"]
}
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
"""执行值计数。"""
column = kwargs.get('column')
normalize = kwargs.get('normalize', False)
if column not in data.columns:
return {'error': f'{column} 不存在'}
value_counts = data[column].value_counts(normalize=normalize)
result = {}
for value, count in value_counts.items():
result[str(value)] = float(count) if normalize else int(count)
return {
'column': column,
'value_counts': result,
'normalized': normalize
}
def is_applicable(self, data_profile: DataProfile) -> bool:
"""适用于所有数据。"""
return True
class GetTimeSeriesTool(AnalysisTool):
"""获取时间序列数据工具。"""
@property
def name(self) -> str:
return "get_time_series"
@property
def description(self) -> str:
return "获取时间序列数据,按时间聚合指定指标。"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"time_column": {
"type": "string",
"description": "时间列名"
},
"value_column": {
"type": "string",
"description": "要聚合的值列名"
},
"aggregation": {
"type": "string",
"description": "聚合方式count, sum, mean, min, max",
"default": "count"
},
"frequency": {
"type": "string",
"description": "时间频率D(天), W(周), M(月), Y(年)",
"default": "D"
}
},
"required": ["time_column"]
}
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
"""执行时间序列分析。"""
time_column = kwargs.get('time_column')
value_column = kwargs.get('value_column')
aggregation = kwargs.get('aggregation', 'count')
frequency = kwargs.get('frequency', 'D')
if time_column not in data.columns:
return {'error': f'时间列 {time_column} 不存在'}
# 转换为日期时间类型
try:
time_data = pd.to_datetime(data[time_column])
except Exception as e:
return {'error': f'无法将 {time_column} 转换为日期时间: {str(e)}'}
# 创建临时 DataFrame
temp_df = pd.DataFrame({'time': time_data})
if value_column:
if value_column not in data.columns:
return {'error': f'值列 {value_column} 不存在'}
temp_df['value'] = data[value_column]
# 设置时间索引
temp_df.set_index('time', inplace=True)
# 按频率重采样
if value_column:
if aggregation == 'count':
result = temp_df.resample(frequency).count()
elif aggregation == 'sum':
result = temp_df.resample(frequency).sum()
elif aggregation == 'mean':
result = temp_df.resample(frequency).mean()
elif aggregation == 'min':
result = temp_df.resample(frequency).min()
elif aggregation == 'max':
result = temp_df.resample(frequency).max()
else:
return {'error': f'不支持的聚合方式: {aggregation}'}
else:
result = temp_df.resample(frequency).size().to_frame('count')
# 转换为字典
time_series = []
for timestamp, row in result.iterrows():
time_series.append({
'time': timestamp.strftime('%Y-%m-%d'),
'value': float(row.iloc[0]) if not pd.isna(row.iloc[0]) else 0.0
})
# 限制返回的数据点数量最多100个隐私保护要求
if len(time_series) > 100:
# 均匀采样以保持趋势
step = len(time_series) / 100
sampled_indices = [int(i * step) for i in range(100)]
time_series = [time_series[i] for i in sampled_indices]
return {
'time_column': time_column,
'value_column': value_column,
'aggregation': aggregation,
'frequency': frequency,
'time_series': time_series,
'total_points': len(result), # 记录原始数据点数量
'returned_points': len(time_series) # 记录返回的数据点数量
}
def is_applicable(self, data_profile: DataProfile) -> bool:
"""适用于包含日期时间列的数据。"""
return any(col.dtype == 'datetime' for col in data_profile.columns)
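The resampling core of the tool above can be shown in a few lines, assuming pandas is available (the timestamps and speeds are made up):

```python
import pandas as pd

df = pd.DataFrame({
    'ts': ['2026-03-01', '2026-03-01', '2026-03-02', '2026-03-04'],
    'speed': [80, 90, 70, 100],
})
tmp = pd.DataFrame({'time': pd.to_datetime(df['ts']), 'value': df['speed']})
tmp = tmp.set_index('time')

# Daily mean; resample fills the empty 2026-03-03 bucket with NaN,
# which the tool maps to 0.0
daily = tmp.resample('D').mean()
series = [
    {'time': ts.strftime('%Y-%m-%d'),
     'value': float(row.iloc[0]) if not pd.isna(row.iloc[0]) else 0.0}
    for ts, row in daily.iterrows()
]
print(series)
```

This shows why the tool's NaN-to-0.0 mapping matters: resampling creates rows for calendar gaps, so downstream charts get a continuous axis.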
class GetCorrelationTool(AnalysisTool):
"""获取相关性分析工具。"""
@property
def name(self) -> str:
return "get_correlation"
@property
def description(self) -> str:
return "计算数值列之间的相关系数矩阵。"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"columns": {
"type": "array",
"items": {"type": "string"},
"description": "要分析的列名列表,如果为空则分析所有数值列"
},
"method": {
"type": "string",
"description": "相关系数方法pearson, spearman, kendall",
"default": "pearson"
}
}
}
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
"""执行相关性分析。"""
columns = kwargs.get('columns', [])
method = kwargs.get('method', 'pearson')
# 如果没有指定列,使用所有数值列
if not columns:
numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()
else:
numeric_cols = [col for col in columns if col in data.columns]
if len(numeric_cols) < 2:
return {'error': '至少需要两个数值列来计算相关性'}
# 计算相关系数矩阵
corr_matrix = data[numeric_cols].corr(method=method)
# 转换为字典格式
correlation = {}
for col1 in corr_matrix.columns:
correlation[col1] = {}
for col2 in corr_matrix.columns:
correlation[col1][col2] = float(corr_matrix.loc[col1, col2])
return {
'columns': numeric_cols,
'method': method,
'correlation_matrix': correlation
}
def is_applicable(self, data_profile: DataProfile) -> bool:
"""适用于包含至少两个数值列的数据。"""
numeric_cols = [col for col in data_profile.columns if col.dtype == 'numeric']
return len(numeric_cols) >= 2
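The tool's output shape is a nested dict built from `DataFrame.corr`. A compact run of that step, assuming pandas and numpy (the vehicle columns are made up):

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'speed': [60, 70, 80, 90],
    'fuel':  [8.0, 9.1, 10.2, 11.0],
    'temp':  [21, 19, 25, 22],
})
# Same default as the tool: all numeric columns, Pearson method
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
corr = df[numeric_cols].corr(method='pearson')

# Nested-dict shape {col1: {col2: r}} matching the tool's return value
matrix = {c1: {c2: float(corr.loc[c1, c2]) for c2 in corr.columns}
          for c1 in corr.columns}
print(matrix['speed']['fuel'])
```

The diagonal is always 1.0, and the matrix is symmetric, so callers can read either `matrix[a][b]` or `matrix[b][a]`.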

View File

@@ -1,325 +0,0 @@
"""统计分析工具。"""
import pandas as pd
import numpy as np
from typing import Dict, Any
from scipy import stats
from src.tools.base import AnalysisTool
from src.models import DataProfile
class CalculateStatisticsTool(AnalysisTool):
"""计算描述性统计工具。"""
@property
def name(self) -> str:
return "calculate_statistics"
@property
def description(self) -> str:
return "计算指定列的描述性统计信息,包括均值、中位数、标准差、最小值、最大值等。"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"column": {
"type": "string",
"description": "要分析的列名"
}
},
"required": ["column"]
}
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
"""执行统计计算。"""
column = kwargs.get('column')
if column not in data.columns:
return {'error': f'{column} 不存在'}
col_data = data[column].dropna()
if not pd.api.types.is_numeric_dtype(col_data):
return {'error': f'{column} 不是数值类型'}
statistics = {
'column': column,
'count': int(len(col_data)),
'mean': float(col_data.mean()),
'median': float(col_data.median()),
'std': float(col_data.std()),
'min': float(col_data.min()),
'max': float(col_data.max()),
'q25': float(col_data.quantile(0.25)),
'q75': float(col_data.quantile(0.75)),
'skewness': float(col_data.skew()),
'kurtosis': float(col_data.kurtosis())
}
return statistics
def is_applicable(self, data_profile: DataProfile) -> bool:
"""适用于包含数值列的数据。"""
return any(col.dtype == 'numeric' for col in data_profile.columns)
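上述统计量都可以直接由 pandas 的 Series 方法得到,下面用一个含极端值的小样本示意偏度的含义(数据为演示而构造):

```python
import pandas as pd

s = pd.Series([1, 2, 2, 3, 3, 3, 4, 4, 5, 100])  # 含一个极端大值
stats = {
    'mean': float(s.mean()),
    'median': float(s.median()),
    'q25': float(s.quantile(0.25)),
    'q75': float(s.quantile(0.75)),
    'skewness': float(s.skew()),  # 正偏:右侧长尾
}
# 极端值把均值拉到远高于中位数
assert stats['mean'] > stats['median']
```

这也说明报告中同时给出均值和中位数的意义:两者差距大时,均值已被极端值扭曲。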
class PerformGroupbyTool(AnalysisTool):
"""执行分组聚合工具。"""
@property
def name(self) -> str:
return "perform_groupby"
@property
def description(self) -> str:
return "按指定列分组,对另一列进行聚合计算(如求和、平均、计数等)。"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"group_by": {
"type": "string",
"description": "分组依据的列名"
},
"value_column": {
"type": "string",
"description": "要聚合的值列名,如果为空则计数"
},
"aggregation": {
"type": "string",
"description": "聚合方式:count, sum, mean, min, max, std",
"default": "count"
}
},
"required": ["group_by"]
}
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
"""执行分组聚合。"""
group_by = kwargs.get('group_by')
value_column = kwargs.get('value_column')
aggregation = kwargs.get('aggregation', 'count')
if group_by not in data.columns:
return {'error': f'分组列 {group_by} 不存在'}
if value_column and value_column not in data.columns:
return {'error': f'值列 {value_column} 不存在'}
# 执行分组聚合
if value_column:
grouped = data.groupby(group_by)[value_column]
else:
grouped = data.groupby(group_by).size()
aggregation = 'count'
if aggregation == 'count':
if value_column:
result = grouped.count()
else:
result = grouped
elif aggregation == 'sum':
result = grouped.sum()
elif aggregation == 'mean':
result = grouped.mean()
elif aggregation == 'min':
result = grouped.min()
elif aggregation == 'max':
result = grouped.max()
elif aggregation == 'std':
result = grouped.std()
else:
return {'error': f'不支持的聚合方式: {aggregation}'}
# 转换为字典
groups = []
for group_value, agg_value in result.items():
groups.append({
'group': str(group_value),
'value': float(agg_value) if not pd.isna(agg_value) else 0.0
})
# 限制返回的分组数量(最多100个,隐私保护要求)
total_groups = len(groups)
if len(groups) > 100:
# 按值排序并取前100个
groups = sorted(groups, key=lambda x: x['value'], reverse=True)[:100]
return {
'group_by': group_by,
'value_column': value_column,
'aggregation': aggregation,
'groups': groups,
'total_groups': total_groups, # 记录原始分组数量
'returned_groups': len(groups) # 记录返回的分组数量
}
def is_applicable(self, data_profile: DataProfile) -> bool:
"""适用于所有数据。"""
return True
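分组聚合对应 pandas 的 `groupby`;计数与数值聚合两种分支可以这样示意(示例数据为假设的工单数据):

```python
import pandas as pd

df = pd.DataFrame({
    'status': ['open', 'open', 'closed', 'closed', 'closed'],
    'hours':  [2.0, 4.0, 1.0, 3.0, 5.0],
})

# 不给 value_column 时:按组计数
counts = df.groupby('status').size()

# 给 value_column 时:按组聚合(这里取均值)
mean_hours = df.groupby('status')['hours'].mean()
```

两者返回的都是以分组值为索引的 Series,工具再把它转成 `{'group': ..., 'value': ...}` 列表。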
class DetectOutliersTool(AnalysisTool):
"""检测异常值工具。"""
@property
def name(self) -> str:
return "detect_outliers"
@property
def description(self) -> str:
return "使用IQR方法或Z-score方法检测数值列中的异常值。"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"column": {
"type": "string",
"description": "要检测的列名"
},
"method": {
"type": "string",
"description": "检测方法:iqr 或 zscore",
"default": "iqr"
},
"threshold": {
"type": "number",
"description": "阈值(IQR 倍数或 Z-score 标准差倍数)",
"default": 1.5
}
},
"required": ["column"]
}
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
"""执行异常值检测。"""
column = kwargs.get('column')
method = kwargs.get('method', 'iqr')
threshold = kwargs.get('threshold', 1.5)
if column not in data.columns:
return {'error': f'{column} 不存在'}
col_data = data[column].dropna()
if not pd.api.types.is_numeric_dtype(col_data):
return {'error': f'列 {column} 不是数值类型'}
if len(col_data) == 0:
return {'error': f'列 {column} 去除缺失值后为空,无法检测异常值'}
if method == 'iqr':
# IQR 方法
q1 = col_data.quantile(0.25)
q3 = col_data.quantile(0.75)
iqr = q3 - q1
lower_bound = q1 - threshold * iqr
upper_bound = q3 + threshold * iqr
outliers = col_data[(col_data < lower_bound) | (col_data > upper_bound)]
elif method == 'zscore':
# Z-score 方法
z_scores = np.abs(stats.zscore(col_data))
outliers = col_data[z_scores > threshold]
else:
return {'error': f'不支持的检测方法: {method}'}
return {
'column': column,
'method': method,
'threshold': threshold,
'outlier_count': int(len(outliers)),
'outlier_percentage': float(len(outliers) / len(col_data) * 100),
'outlier_values': outliers.head(20).tolist(), # 最多返回20个异常值
'bounds': {
'lower': float(lower_bound),
'upper': float(upper_bound)
} if method == 'iqr' else None
}
def is_applicable(self, data_profile: DataProfile) -> bool:
"""适用于包含数值列的数据。"""
return any(col.dtype == 'numeric' for col in data_profile.columns)
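IQR 分支的判定逻辑可以独立验证如下(阈值取默认的 1.5 倍 IQR,数据为演示而构造):

```python
import pandas as pd

s = pd.Series([10, 11, 12, 12, 13, 13, 14, 14, 15, 100])
q1, q3 = s.quantile(0.25), s.quantile(0.75)
iqr = q3 - q1
lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr
# 只有 100 超出上界,被判定为异常值
outliers = s[(s < lower) | (s > upper)]
```

IQR 方法只依赖分位数,因此对极端值本身不敏感,适合做默认方法;Z-score 方法在极端值会同时抬高均值和标准差时可能漏检。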
class CalculateTrendTool(AnalysisTool):
"""计算趋势工具。"""
@property
def name(self) -> str:
return "calculate_trend"
@property
def description(self) -> str:
return "计算时间序列数据的趋势,包括线性回归斜率、增长率等。"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"time_column": {
"type": "string",
"description": "时间列名"
},
"value_column": {
"type": "string",
"description": "值列名"
}
},
"required": ["time_column", "value_column"]
}
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
"""执行趋势计算。"""
time_column = kwargs.get('time_column')
value_column = kwargs.get('value_column')
if time_column not in data.columns:
return {'error': f'时间列 {time_column} 不存在'}
if value_column not in data.columns:
return {'error': f'值列 {value_column} 不存在'}
# 转换时间列
try:
time_data = pd.to_datetime(data[time_column])
except Exception as e:
return {'error': f'无法将 {time_column} 转换为日期时间: {str(e)}'}
# 同步剔除缺失值,保证时间与数值一一对应
# (原实现先分别处理两列再按长度截断,会导致时间与数值错位)
valid_mask = time_data.notna() & data[value_column].notna()
time_valid = time_data[valid_mask]
value_data = data.loc[valid_mask, value_column].values
if len(value_data) < 2:
return {'error': '数据点太少,无法计算趋势'}
# 创建数值型时间索引(天数)
time_numeric = (time_valid - time_valid.min()).dt.days.values
# 线性回归
slope, intercept, r_value, p_value, std_err = stats.linregress(
time_numeric, value_data
)
# 计算增长率
first_value = value_data[0]
last_value = value_data[-1]
growth_rate = ((last_value - first_value) / first_value * 100) if first_value != 0 else 0
return {
'time_column': time_column,
'value_column': value_column,
'slope': float(slope),
'intercept': float(intercept),
'r_squared': float(r_value ** 2),
'p_value': float(p_value),
'growth_rate': float(growth_rate),
'trend': 'increasing' if slope > 0 else 'decreasing' if slope < 0 else 'stable'
}
def is_applicable(self, data_profile: DataProfile) -> bool:
"""适用于包含日期时间列和数值列的数据。"""
has_datetime = any(col.dtype == 'datetime' for col in data_profile.columns)
has_numeric = any(col.dtype == 'numeric' for col in data_profile.columns)
return has_datetime and has_numeric
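趋势计算依赖 scipy 的 `stats.linregress`。关键点是时间与数值必须在同一行上对齐,缺失值要整行剔除(示例数据为演示而构造,恰好落在直线 y = 10 + 2x 上):

```python
import numpy as np
import pandas as pd
from scipy import stats

df = pd.DataFrame({
    'date': pd.to_datetime(['2024-01-01', '2024-01-02', '2024-01-03',
                            '2024-01-04', '2024-01-05']),
    'count': [10.0, 12.0, np.nan, 16.0, 18.0],
})

# 整行剔除缺失,保持时间与数值一一对应
valid = df.dropna(subset=['count'])
days = (valid['date'] - valid['date'].min()).dt.days.values
slope, intercept, r_value, p_value, std_err = stats.linregress(
    days, valid['count'].values
)
```

如果改为分别对两列做 `dropna` 再截断长度,第 3 行的缺失会让后续数值整体前移一格,斜率就不再是 2。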


@@ -1,182 +0,0 @@
"""工具管理器,负责根据数据特征动态选择和管理工具。"""
from typing import List, Dict, Any
import pandas as pd
from src.tools.base import AnalysisTool, ToolRegistry
from src.models import DataProfile
class ToolManager:
"""
工具管理器,负责根据数据特征动态选择合适的工具。
"""
def __init__(self, registry: ToolRegistry = None):
"""
初始化工具管理器。
参数:
registry: 工具注册表,如果为 None 则创建新的注册表
"""
self.registry = registry if registry else ToolRegistry()
self._missing_tools: List[str] = []
def select_tools(self, data_profile: DataProfile) -> List[AnalysisTool]:
"""
根据数据画像选择合适的工具。
参数:
data_profile: 数据画像
返回:
适用的工具列表
"""
selected_tools = []
# 检查时间字段
if self._has_datetime_column(data_profile):
selected_tools.extend(self._get_time_series_tools())
# 检查分类字段
if self._has_categorical_column(data_profile):
selected_tools.extend(self._get_categorical_tools())
# 检查数值字段
if self._has_numeric_column(data_profile):
selected_tools.extend(self._get_numeric_tools())
# 检查地理字段
if self._has_geo_column(data_profile):
selected_tools.extend(self._get_geo_tools())
# 添加通用工具(适用于所有数据)
selected_tools.extend(self._get_universal_tools())
# 去重
unique_tools = []
seen_names = set()
for tool in selected_tools:
if tool.name not in seen_names:
unique_tools.append(tool)
seen_names.add(tool.name)
return unique_tools
def _has_datetime_column(self, data_profile: DataProfile) -> bool:
"""检查是否包含日期时间列。"""
return any(col.dtype == 'datetime' for col in data_profile.columns)
def _has_categorical_column(self, data_profile: DataProfile) -> bool:
"""检查是否包含分类列。"""
return any(col.dtype == 'categorical' for col in data_profile.columns)
def _has_numeric_column(self, data_profile: DataProfile) -> bool:
"""检查是否包含数值列。"""
return any(col.dtype == 'numeric' for col in data_profile.columns)
def _has_geo_column(self, data_profile: DataProfile) -> bool:
"""检查是否包含地理列。"""
# 检查列名是否包含地理相关关键词
geo_keywords = ['lat', 'lon', 'latitude', 'longitude', 'location', 'address', 'city', 'country']
for col in data_profile.columns:
col_name_lower = col.name.lower()
if any(keyword in col_name_lower for keyword in geo_keywords):
return True
return False
def _get_time_series_tools(self) -> List[AnalysisTool]:
"""获取时间序列分析工具。"""
tools = []
tool_names = ['get_time_series', 'calculate_trend', 'create_line_chart']
for tool_name in tool_names:
try:
tool = self.registry.get_tool(tool_name)
tools.append(tool)
except KeyError:
self._missing_tools.append(tool_name)
return tools
def _get_categorical_tools(self) -> List[AnalysisTool]:
"""获取分类数据分析工具。"""
tools = []
tool_names = ['get_column_distribution', 'get_value_counts', 'perform_groupby',
'create_bar_chart', 'create_pie_chart']
for tool_name in tool_names:
try:
tool = self.registry.get_tool(tool_name)
tools.append(tool)
except KeyError:
self._missing_tools.append(tool_name)
return tools
def _get_numeric_tools(self) -> List[AnalysisTool]:
"""获取数值数据分析工具。"""
tools = []
tool_names = ['calculate_statistics', 'detect_outliers', 'get_correlation', 'create_heatmap']
for tool_name in tool_names:
try:
tool = self.registry.get_tool(tool_name)
tools.append(tool)
except KeyError:
self._missing_tools.append(tool_name)
return tools
def _get_geo_tools(self) -> List[AnalysisTool]:
"""获取地理数据分析工具。"""
tools = []
# 目前没有实现地理工具,记录为缺失
tool_names = ['create_map_visualization']
for tool_name in tool_names:
try:
tool = self.registry.get_tool(tool_name)
tools.append(tool)
except KeyError:
self._missing_tools.append(tool_name)
return tools
def _get_universal_tools(self) -> List[AnalysisTool]:
"""获取通用工具(适用于所有数据)。"""
tools = []
# 通用工具已经在其他类别中包含了
return tools
def get_missing_tools(self) -> List[str]:
"""
获取缺失的工具列表。
返回:
缺失的工具名称列表
"""
return list(set(self._missing_tools))
def clear_missing_tools(self) -> None:
"""清空缺失工具列表。"""
self._missing_tools = []
def get_tool_descriptions(self, tools: List[AnalysisTool]) -> List[Dict[str, Any]]:
"""
获取工具的描述信息(供 AI 选择)。
参数:
tools: 工具列表
返回:
工具描述列表
"""
descriptions = []
for tool in tools:
descriptions.append({
'name': tool.name,
'description': tool.description,
'parameters': tool.parameters
})
return descriptions
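以地理字段检测为例,工具选择逻辑只是对列名做关键词匹配,可以最小化示意如下:

```python
GEO_KEYWORDS = ['lat', 'lon', 'latitude', 'longitude',
                'location', 'address', 'city', 'country']

def has_geo_column(column_names):
    """与 _has_geo_column 相同的列名关键词匹配逻辑。"""
    return any(kw in name.lower()
               for name in column_names
               for kw in GEO_KEYWORDS)
```

这种匹配是启发式的:`'plat_form'` 这类列名也会因包含 `lat` 被误判,实际使用中可结合取值范围再做二次确认。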


@@ -1,443 +0,0 @@
"""可视化工具。"""
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg') # 使用非交互式后端
import matplotlib.pyplot as plt
from typing import Dict, Any
import os
from pathlib import Path
from src.tools.base import AnalysisTool
from src.models import DataProfile
# 尝试导入 seaborn如果不可用则使用 matplotlib
try:
import seaborn as sns
HAS_SEABORN = True
except ImportError:
HAS_SEABORN = False
# 设置中文字体支持
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
class CreateBarChartTool(AnalysisTool):
"""创建柱状图工具。"""
@property
def name(self) -> str:
return "create_bar_chart"
@property
def description(self) -> str:
return "创建柱状图,用于展示分类数据的分布或比较。"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"x_column": {
"type": "string",
"description": "X轴列名(分类变量)"
},
"y_column": {
"type": "string",
"description": "Y轴列名(数值变量),如果为空则计数"
},
"title": {
"type": "string",
"description": "图表标题",
"default": "柱状图"
},
"output_path": {
"type": "string",
"description": "输出文件路径",
"default": "bar_chart.png"
},
"top_n": {
"type": "integer",
"description": "只显示前N个类别",
"default": 20
}
},
"required": ["x_column"]
}
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
"""执行柱状图生成。"""
x_column = kwargs.get('x_column')
y_column = kwargs.get('y_column')
title = kwargs.get('title', '柱状图')
output_path = kwargs.get('output_path', 'bar_chart.png')
top_n = kwargs.get('top_n', 20)
if x_column not in data.columns:
return {'error': f'{x_column} 不存在'}
if y_column and y_column not in data.columns:
return {'error': f'{y_column} 不存在'}
try:
# 准备数据
if y_column:
# 按 x_column 分组,对 y_column 求和
plot_data = data.groupby(x_column)[y_column].sum().sort_values(ascending=False).head(top_n)
else:
# 计数
plot_data = data[x_column].value_counts().head(top_n)
# 创建图表
fig, ax = plt.subplots(figsize=(12, 6))
plot_data.plot(kind='bar', ax=ax)
ax.set_title(title, fontsize=14, fontweight='bold')
ax.set_xlabel(x_column, fontsize=12)
ax.set_ylabel(y_column if y_column else '计数', fontsize=12)
ax.tick_params(axis='x', rotation=45)
plt.tight_layout()
# 确保输出目录存在
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# 保存图表
plt.savefig(output_path, dpi=100, bbox_inches='tight')
plt.close(fig)
return {
'success': True,
'chart_path': output_path,
'chart_type': 'bar',
'data_points': len(plot_data),
'x_column': x_column,
'y_column': y_column
}
except Exception as e:
return {'error': f'生成柱状图失败: {str(e)}'}
def is_applicable(self, data_profile: DataProfile) -> bool:
"""适用于所有数据。"""
return True
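各图表工具共同的骨架是 Agg 后端 + `savefig`,可以在无显示环境下验证(输出到临时目录,路径与数据均为示例):

```python
import os
import tempfile

import matplotlib
matplotlib.use('Agg')  # 非交互式后端,无显示环境也能渲染
import matplotlib.pyplot as plt
import pandas as pd

counts = pd.Series({'TSP': 40, 'APP': 30, 'DK': 20})
fig, ax = plt.subplots(figsize=(6, 4))
counts.plot(kind='bar', ax=ax)
ax.set_title('demo bar chart')
out_path = os.path.join(tempfile.mkdtemp(), 'bar_chart.png')
plt.savefig(out_path, dpi=100, bbox_inches='tight')
plt.close(fig)  # 及时关闭,避免批量出图时内存累积
```

`plt.close(fig)` 这一步在循环生成多张图时尤其重要,否则每张图的 figure 对象都会常驻内存。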
class CreateLineChartTool(AnalysisTool):
"""创建折线图工具。"""
@property
def name(self) -> str:
return "create_line_chart"
@property
def description(self) -> str:
return "创建折线图,用于展示时间序列数据或趋势变化。"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"x_column": {
"type": "string",
"description": "X轴列名(通常是时间)"
},
"y_column": {
"type": "string",
"description": "Y轴列名(数值变量)"
},
"title": {
"type": "string",
"description": "图表标题",
"default": "折线图"
},
"output_path": {
"type": "string",
"description": "输出文件路径",
"default": "line_chart.png"
}
},
"required": ["x_column", "y_column"]
}
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
"""执行折线图生成。"""
x_column = kwargs.get('x_column')
y_column = kwargs.get('y_column')
title = kwargs.get('title', '折线图')
output_path = kwargs.get('output_path', 'line_chart.png')
if x_column not in data.columns:
return {'error': f'{x_column} 不存在'}
if y_column not in data.columns:
return {'error': f'{y_column} 不存在'}
try:
# 准备数据
plot_data = data[[x_column, y_column]].copy()
plot_data = plot_data.sort_values(x_column)
# 如果数据点太多,采样
if len(plot_data) > 1000:
step = len(plot_data) // 1000
plot_data = plot_data.iloc[::step]
# 创建图表
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(plot_data[x_column], plot_data[y_column], marker='o', markersize=3, linewidth=2)
ax.set_title(title, fontsize=14, fontweight='bold')
ax.set_xlabel(x_column, fontsize=12)
ax.set_ylabel(y_column, fontsize=12)
ax.grid(True, alpha=0.3)
plt.xticks(rotation=45)
plt.tight_layout()
# 确保输出目录存在
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# 保存图表
plt.savefig(output_path, dpi=100, bbox_inches='tight')
plt.close(fig)
return {
'success': True,
'chart_path': output_path,
'chart_type': 'line',
'data_points': len(plot_data),
'x_column': x_column,
'y_column': y_column
}
except Exception as e:
return {'error': f'生成折线图失败: {str(e)}'}
def is_applicable(self, data_profile: DataProfile) -> bool:
"""适用于包含数值列的数据。"""
return any(col.dtype == 'numeric' for col in data_profile.columns)
class CreatePieChartTool(AnalysisTool):
"""创建饼图工具。"""
@property
def name(self) -> str:
return "create_pie_chart"
@property
def description(self) -> str:
return "创建饼图,用于展示各部分占整体的比例。"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"column": {
"type": "string",
"description": "要分析的列名"
},
"title": {
"type": "string",
"description": "图表标题",
"default": "饼图"
},
"output_path": {
"type": "string",
"description": "输出文件路径",
"default": "pie_chart.png"
},
"top_n": {
"type": "integer",
"description": "只显示前N个类别,其余归为'其他'",
"default": 10
}
},
"required": ["column"]
}
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
"""执行饼图生成。"""
column = kwargs.get('column')
title = kwargs.get('title', '饼图')
output_path = kwargs.get('output_path', 'pie_chart.png')
top_n = kwargs.get('top_n', 10)
if column not in data.columns:
return {'error': f'{column} 不存在'}
try:
# 准备数据
value_counts = data[column].value_counts()
if len(value_counts) > top_n:
# 只保留前 N 个,其余归为"其他"
top_values = value_counts.head(top_n)
other_sum = value_counts.iloc[top_n:].sum()
plot_data = pd.concat([top_values, pd.Series({'其他': other_sum})])
else:
plot_data = value_counts
# 创建图表
fig, ax = plt.subplots(figsize=(10, 8))
colors = plt.cm.Set3(range(len(plot_data)))
wedges, texts, autotexts = ax.pie(
plot_data,
labels=plot_data.index,
autopct='%1.1f%%',
colors=colors,
startangle=90
)
# 设置文本样式
for text in texts:
text.set_fontsize(10)
for autotext in autotexts:
autotext.set_color('white')
autotext.set_fontweight('bold')
autotext.set_fontsize(9)
ax.set_title(title, fontsize=14, fontweight='bold')
plt.tight_layout()
# 确保输出目录存在
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# 保存图表
plt.savefig(output_path, dpi=100, bbox_inches='tight')
plt.close(fig)
return {
'success': True,
'chart_path': output_path,
'chart_type': 'pie',
'categories': len(plot_data),
'column': column
}
except Exception as e:
return {'error': f'生成饼图失败: {str(e)}'}
def is_applicable(self, data_profile: DataProfile) -> bool:
"""适用于所有数据。"""
return True
class CreateHeatmapTool(AnalysisTool):
"""创建热力图工具。"""
@property
def name(self) -> str:
return "create_heatmap"
@property
def description(self) -> str:
return "创建热力图,用于展示数值矩阵或相关性矩阵。"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"columns": {
"type": "array",
"items": {"type": "string"},
"description": "要分析的列名列表,如果为空则使用所有数值列"
},
"title": {
"type": "string",
"description": "图表标题",
"default": "相关性热力图"
},
"output_path": {
"type": "string",
"description": "输出文件路径",
"default": "heatmap.png"
},
"method": {
"type": "string",
"description": "相关系数方法:pearson, spearman, kendall",
"default": "pearson"
}
}
}
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
"""执行热力图生成。"""
columns = kwargs.get('columns', [])
title = kwargs.get('title', '相关性热力图')
output_path = kwargs.get('output_path', 'heatmap.png')
method = kwargs.get('method', 'pearson')
try:
# 如果没有指定列,使用所有数值列
if not columns:
numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()
else:
numeric_cols = [col for col in columns if col in data.columns]
if len(numeric_cols) < 2:
return {'error': '至少需要两个数值列来创建热力图'}
# 计算相关系数矩阵
corr_matrix = data[numeric_cols].corr(method=method)
# 创建图表
fig, ax = plt.subplots(figsize=(10, 8))
if HAS_SEABORN:
# 使用 seaborn 创建更美观的热力图
sns.heatmap(
corr_matrix,
annot=True,
fmt='.2f',
cmap='coolwarm',
center=0,
square=True,
linewidths=1,
cbar_kws={"shrink": 0.8},
ax=ax
)
else:
# 使用 matplotlib 创建基本热力图
im = ax.imshow(corr_matrix, cmap='coolwarm', aspect='auto', vmin=-1, vmax=1)
ax.set_xticks(range(len(corr_matrix.columns)))
ax.set_yticks(range(len(corr_matrix.columns)))
ax.set_xticklabels(corr_matrix.columns, rotation=45, ha='right')
ax.set_yticklabels(corr_matrix.columns)
# 添加数值标注
for i in range(len(corr_matrix.columns)):
for j in range(len(corr_matrix.columns)):
text = ax.text(j, i, f'{corr_matrix.iloc[i, j]:.2f}',
ha="center", va="center", color="black", fontsize=9)
plt.colorbar(im, ax=ax, shrink=0.8)
ax.set_title(title, fontsize=14, fontweight='bold')
plt.tight_layout()
# 确保输出目录存在
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# 保存图表
plt.savefig(output_path, dpi=100, bbox_inches='tight')
plt.close(fig)
return {
'success': True,
'chart_path': output_path,
'chart_type': 'heatmap',
'columns': numeric_cols,
'method': method
}
except Exception as e:
return {'error': f'生成热力图失败: {str(e)}'}
def is_applicable(self, data_profile: DataProfile) -> bool:
"""适用于包含至少两个数值列的数据。"""
numeric_cols = [col for col in data_profile.columns if col.dtype == 'numeric']
return len(numeric_cols) >= 2


@@ -1,4 +0,0 @@
@echo off
echo Starting IOV Data Analysis Agent...
python bootstrap.py
pause

templates/iot_ops_report.md Normal file

@@ -0,0 +1,140 @@
# 《XX品牌车联网运维分析报告》
## 1. 整体问题分布与效率分析
### 1.1 工单类型分布与趋势
本月共受理工单 {总工单数} 单。
其中:
- TSP问题{数量}单 ({占比}%)
- APP问题{数量}单 ({占比}%)
- DK问题{数量}单 ({占比}%)
- 咨询类:{数量}单 ({占比}%)
> (可增加环比变化趋势)
---
### 1.2 问题解决效率分析
> (后续可增加环比变化趋势,如工单总流转时间、环比增长趋势图)
| 工单类型 | 总数量 | 一线处理数量 | 反馈二线数量 | 平均时长(h) | 中位数(h) | 一次解决率(%) | TSP处理次数 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| TSP问题 | {数值} | | | {数值} | {数值} | {数值} | {数值} |
| APP问题 | {数值} | | | {数值} | {数值} | {数值} | {数值} |
| DK问题 | {数值} | | | {数值} | {数值} | {数值} | {数值} |
| 咨询类 | {数值} | | | {数值} | {数值} | {数值} | {数值} |
| 合计 | | | | | | | |
---
### 1.3 问题车型分布
---
## 2. 各类问题专题分析
### 2.1 TSP问题专题
当月总体情况概述:
| 工单类型 | 总数量 | 海外一线处理数量 | 国内二线数量 | 平均时长(h) | 中位数(h) |
| --- | --- | --- | --- | --- | --- |
| TSP问题 | {数值} | | | {数值} | {数值} |
#### 2.1.1 TSP问题二级分类+三级分布
#### 2.1.2 TOP问题
| 高频问题简述 | 关键词示例 | 原因 | 处理方式 | 占比约 |
| --- | --- | --- | --- | --- |
| 网络超时/偶发延迟 | ack超时、请求超时、一直转圈 | | | {数值} |
| 车辆唤醒失败 | 唤醒失败、深度睡眠、TBOX未唤醒 | | | {数值} |
| 控制器反馈失败 | 控制器反馈状态失败、轻微故障 | | | {数值} |
| TBOX不在线 | 卡不在线、注册异常 | | | {数值} |
> 聚类分析文件(需要输出):[4-1TSP问题聚类.xlsx]
---
### 2.2 APP问题专题
当月总体情况概述:
| 工单类型 | 总数量 | 一线处理数量 | 反馈二线数量 | 一线平均处理时长(h) | 二线平均处理时长(h) | 平均时长(h) | 中位数(h) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| APP问题 | {数值} | | | {数值} | {数值} | {数值} | {数值} |
#### 2.2.1 APP问题二级分类分布
#### 2.2.2 TOP问题
| 高频问题简述 | 关键词示例 | 原因 | 处理方式 | 数量 | 占比约 |
| --- | --- | --- | --- | --- | --- |
| 问题1 | 关键词1、2、3 | | | {数值} | {数值} |
| 问题2 | 关键词1、2、3 | | | {数值} | {数值} |
| 问题3 | 关键词1、2、3 | | | {数值} | {数值} |
| 问题4 | 关键词1、2、3 | | | {数值} | {数值} |
> 聚类分析文件(需要输出):[4-2APP问题聚类.xlsx]
---
### 2.3 TBOX问题专题
> 总流转时间和环比增长趋势(可参考柱状+折线组合图)
#### 2.3.1 TBOX问题二级分类分布
#### 2.3.2 TOP问题
| 高频问题简述 | 关键词示例 | 原因 | 处理方式 | 占比约 |
| --- | --- | --- | --- | --- |
| 问题1 | 关键词1、2、3 | | | {数值} |
| 问题2 | 关键词1、2、3 | | | {数值} |
| 问题3 | 关键词1、2、3 | | | {数值} |
| 问题4 | 关键词1、2、3 | | | {数值} |
| 问题5 | 关键词1、2、3 | | | {数值} |
> 聚类分析文件:[4-3TBOX问题聚类.xlsx]
---
### 2.4 DMC专题
> 总流转时间和环比增长趋势(可参考柱状+折线组合图)
#### 2.4.1 DMC类二级分类分布与解决时长
#### 2.4.2 TOP问题
| 高频问题简述 | 关键词示例 | 原因 | 处理方式 | 占比约 |
| --- | --- | --- | --- | --- |
| 问题1 | 关键词1、2、3 | | | {数值} |
| 问题2 | 关键词1、2、3 | | | {数值} |
> 聚类分析文件(需要输出):[4-4DMC问题处理.xlsx]
---
### 2.5 咨询类专题
> 总流转时间和环比增长趋势(可参考柱状+折线组合图)
#### 2.5.1 咨询类二级分类分布与解决时长
#### 2.5.2 TOP咨询
| 高频问题简述 | 关键词示例 | 原因 | 处理方式 | 占比约 |
| --- | --- | --- | --- | --- |
| 问题1 | 关键词1、2、3 | | | {数值} |
| 问题2 | 关键词1、2、3 | | | {数值} |
> 咨询类文件(需要输出):[4-5咨询类问题处理.xlsx]
---
## 3. 建议与附件
- 工单客诉详情见附件:


@@ -1,145 +0,0 @@
# Test Results Summary - Task 22 Final Checkpoint
## Overall Results
- **Total Tests**: 328
- **Passed**: 314 (95.7%)
- **Failed**: 14 (4.3%)
- **Execution Time**: 182.78s (3:02)
## Failed Tests Analysis
### 1. Property-Based Test Failures (4 tests)
#### test_data_access_properties.py::test_data_profile_completeness
- **Issue**: `hypothesis.errors.FailedHealthCheck` - Generated inputs consumed too much entropy
- **Root Cause**: Data generation strategy creates too large datasets
- **Fix Needed**: Add `suppress_health_check=[HealthCheck.data_too_large]` to settings
#### test_data_understanding_properties.py::test_data_type_inference
- **Issue**: `TypeError: understand_data() got an unexpected keyword argument 'file_path'`
- **Root Cause**: Function signature mismatch in test
- **Fix Needed**: Update test to match actual function signature
#### test_data_understanding_properties.py::test_data_profile_completeness
- **Issue**: Same as above - `TypeError: understand_data() got an unexpected keyword argument 'file_path'`
- **Fix Needed**: Update test to match actual function signature
#### test_tools_properties.py::test_tool_output_filtering
- **Issue**: `hypothesis.errors.FailedHealthCheck` - Generated inputs consumed too much entropy
- **Fix Needed**: Add `suppress_health_check=[HealthCheck.data_too_large]` to settings
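The fix looks roughly like this: suppress the health check explicitly in `@settings` (the strategy below is a made-up placeholder, only the decorator usage matters):

```python
from hypothesis import HealthCheck, given, settings, strategies as st

# suppress_health_check tells hypothesis not to abort when generated
# inputs are large; max_examples is lowered to keep the run fast
@settings(suppress_health_check=[HealthCheck.data_too_large], max_examples=10)
@given(st.lists(st.integers(), max_size=500))
def test_profile_accepts_any_list(values):
    assert isinstance(values, list)

test_profile_accepts_any_list()  # a @given-decorated test is directly callable
```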
### 2. Integration Test Failures (7 tests)
#### test_integration.py::TestEndToEndAnalysis (4 tests)
- **Issue**: `AssertionError: 分析失败: [Errno 13] Permission denied`
- **Root Cause**: Permission denied when accessing temp directory
- **Tests Affected**:
- test_complete_analysis_without_requirement
- test_analysis_with_requirement
- test_template_based_analysis
- test_different_data_types
- **Fix Needed**: Use proper temp directory with write permissions
#### test_integration.py::TestOrchestrator::test_orchestrator_stages
- **Issue**: `assert None is not None`
- **Root Cause**: Orchestrator not returning expected result
- **Fix Needed**: Debug orchestrator implementation
#### test_integration.py::TestProgressTracking::test_progress_callback
- **Issue**: `assert 4 == 5` - Progress callback not called expected number of times
- **Fix Needed**: Verify progress tracking implementation
#### test_integration.py::TestOutputFiles::test_report_file_creation
- **Issue**: `assert False is True` - Report file not created
- **Root Cause**: Likely related to permission issues
- **Fix Needed**: Ensure proper file creation permissions
### 3. Performance Test Failures (3 tests)
#### test_performance.py::TestDataUnderstandingPerformance::test_large_dataset_performance
- **Issue**: `AssertionError: 大数据集理解耗时 30.44秒超过30秒限制`
- **Root Cause**: Performance slightly exceeds 30-second threshold (30.44s)
- **Status**: Acceptable - only 0.44s over limit, within margin of error
#### test_performance.py::TestFullAnalysisPerformance::test_small_dataset_full_analysis
- **Issue**: `assert False is True`
- **Root Cause**: Full analysis not completing successfully
- **Fix Needed**: Debug full analysis workflow
#### test_performance.py::TestFullAnalysisPerformance::test_large_dataset_full_analysis
- **Issue**: `assert False is True`
- **Root Cause**: Full analysis not completing successfully
- **Fix Needed**: Debug full analysis workflow
## Warnings Summary
### Critical Warnings
1. **DeprecationWarning**: `is_categorical_dtype` is deprecated
- Location: `src/engines/data_understanding.py:82`
- Fix: Use `isinstance(dtype, pd.CategoricalDtype)` instead
2. **FutureWarning**: `'H'` frequency is deprecated
- Location: `tests/test_performance.py:104, 264`
- Fix: Use `'h'` instead of `'H'`
3. **UserWarning**: Could not infer datetime format
- Location: `src/data_access.py:173`, `src/tools/query_tools.py:177`
- Fix: Specify explicit format for `pd.to_datetime()`
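The three warning fixes above can be sketched as follows (all are the replacements recommended by the pandas deprecation messages):

```python
import pandas as pd

# 1) is_categorical_dtype is deprecated -> check the dtype with isinstance
s = pd.Series(['a', 'b', 'a'], dtype='category')
is_cat = isinstance(s.dtype, pd.CategoricalDtype)

# 2) frequency alias 'H' is deprecated -> use lowercase 'h'
idx = pd.date_range('2024-01-01', periods=3, freq='h')

# 3) pass an explicit format to to_datetime to avoid inference warnings
ts = pd.to_datetime(['2024-01-01', '2024-01-02'], format='%Y-%m-%d')
```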
## Acceptance Criteria Status
### Scenario 1: 完全自主分析
- ✅ AI 能识别数据类型 (Passed)
- ✅ AI 能推断关键字段的业务含义 (Passed)
- ✅ AI 能自主决定分析维度 (Passed)
- ✅ AI 能生成合理的分析计划 (Passed)
- ⚠️ AI 能执行分析并生成报告 (Integration tests failing due to permissions)
- ✅ 报告包含关键发现和洞察 (Passed)
### Scenario 2: 指定分析方向
- ✅ AI 能理解"健康度"的业务含义 (Passed)
- ✅ AI 能将抽象概念转化为具体指标 (Passed)
- ✅ AI 能根据数据特征选择合适的分析方法 (Passed)
- ✅ AI 能生成针对性的报告 (Passed)
### Scenario 3: 参考模板分析
- ✅ AI 能理解模板的结构和要求 (Passed)
- ✅ AI 能检查数据是否满足模板要求 (Passed)
- ✅ AI 能按模板结构组织报告 (Passed)
- ✅ AI 能灵活调整 (Passed)
### Scenario 4: 迭代深入分析
- ✅ AI 能识别异常或关键发现 (Passed)
- ✅ AI 能自主决定是否需要深入分析 (Passed)
- ✅ AI 能动态调整分析计划 (Passed)
- ✅ AI 能追踪问题的根因 (Passed)
### 工具动态性验收
- ✅ 系统根据数据特征自动启用相关工具 (Passed)
- ✅ 系统根据数据特征自动禁用无关工具 (Passed)
- ✅ AI 能识别需要但缺失的工具 (Passed)
## Recommendations
### High Priority Fixes
1. Fix permission issues in integration tests (use proper temp directories)
2. Fix function signature mismatches in property tests
3. Add health check suppressions for large data tests
### Medium Priority Fixes
1. Update deprecated pandas API calls
2. Fix datetime format warnings
3. Debug full analysis workflow failures
### Low Priority
1. Optimize large dataset performance (currently 30.44s vs 30s limit)
2. Verify progress tracking callback counts
## Conclusion
The system has achieved **95.7% test pass rate** with most core functionality working correctly. The failures are primarily:
- **Environmental issues** (permissions, temp directories)
- **Test configuration issues** (health checks, function signatures)
- **Minor performance issues** (0.44s over threshold)
All core acceptance criteria are met, with only integration test failures due to environmental issues preventing full end-to-end validation.


@@ -1 +0,0 @@
"""Tests for the AI data analysis agent."""


@@ -1,111 +0,0 @@
"""Pytest configuration and fixtures."""
import pytest
from hypothesis import settings, Verbosity
# Configure hypothesis settings
settings.register_profile("default", max_examples=100, verbosity=Verbosity.normal)
settings.register_profile("ci", max_examples=1000, verbosity=Verbosity.verbose)
settings.load_profile("default")
@pytest.fixture
def sample_column_info():
"""Fixture providing a sample ColumnInfo instance."""
from src.models import ColumnInfo
return ColumnInfo(
name='test_column',
dtype='numeric',
missing_rate=0.1,
unique_count=50,
sample_values=[1, 2, 3, 4, 5],
statistics={'mean': 3.0, 'std': 1.5}
)
@pytest.fixture
def sample_data_profile():
"""Fixture providing a sample DataProfile instance."""
from src.models import DataProfile, ColumnInfo
columns = [
ColumnInfo(name='id', dtype='numeric', missing_rate=0.0, unique_count=100),
ColumnInfo(name='status', dtype='categorical', missing_rate=0.0, unique_count=3),
]
return DataProfile(
file_path='test.csv',
row_count=100,
column_count=2,
columns=columns,
inferred_type='ticket',
key_fields={'status': 'ticket status'},
quality_score=85.0,
summary='Test data profile'
)
@pytest.fixture
def sample_analysis_objective():
"""Fixture providing a sample AnalysisObjective instance."""
from src.models import AnalysisObjective
return AnalysisObjective(
name='Test Objective',
description='Test analysis objective',
metrics=['metric1', 'metric2'],
priority=5
)
@pytest.fixture
def sample_requirement_spec(sample_analysis_objective):
"""Fixture providing a sample RequirementSpec instance."""
from src.models import RequirementSpec
return RequirementSpec(
user_input='Test requirement',
objectives=[sample_analysis_objective],
constraints=['no_pii'],
expected_outputs=['report']
)
@pytest.fixture
def sample_analysis_task():
"""Fixture providing a sample AnalysisTask instance."""
from src.models import AnalysisTask
return AnalysisTask(
id='task_1',
name='Test Task',
description='Test analysis task',
priority=5,
dependencies=[],
required_tools=['tool1'],
expected_output='Test output'
)
@pytest.fixture
def sample_analysis_plan(sample_analysis_objective, sample_analysis_task):
"""Fixture providing a sample AnalysisPlan instance."""
from src.models import AnalysisPlan
return AnalysisPlan(
objectives=[sample_analysis_objective],
tasks=[sample_analysis_task],
tool_config={'tool1': 'config1'},
estimated_duration=300
)
@pytest.fixture
def sample_analysis_result():
"""Fixture providing a sample AnalysisResult instance."""
from src.models import AnalysisResult
return AnalysisResult(
task_id='task_1',
task_name='Test Task',
success=True,
data={'count': 100},
visualizations=['chart.png'],
insights=['Key finding'],
execution_time=5.0
)


@@ -1,342 +0,0 @@
"""Unit tests for analysis planning engine."""
import pytest
from src.engines.analysis_planning import (
plan_analysis,
validate_task_dependencies,
_fallback_analysis_planning,
_has_circular_dependency
)
from src.models.data_profile import DataProfile, ColumnInfo
from src.models.requirement_spec import RequirementSpec, AnalysisObjective
from src.models.analysis_plan import AnalysisTask
@pytest.fixture
def sample_data_profile():
"""Create a sample data profile for testing."""
return DataProfile(
file_path='test.csv',
row_count=1000,
column_count=5,
columns=[
ColumnInfo(
name='created_at',
dtype='datetime',
missing_rate=0.0,
unique_count=1000
),
ColumnInfo(
name='status',
dtype='categorical',
missing_rate=0.1,
unique_count=5
),
ColumnInfo(
name='type',
dtype='categorical',
missing_rate=0.0,
unique_count=10
),
ColumnInfo(
name='priority',
dtype='numeric',
missing_rate=0.0,
unique_count=5
),
ColumnInfo(
name='description',
dtype='text',
missing_rate=0.05,
unique_count=950
)
],
inferred_type='ticket',
key_fields={'time': 'created_at', 'status': 'status'},
quality_score=85.0,
summary='Ticket data with 1000 rows'
)
@pytest.fixture
def sample_requirement():
"""Create a sample requirement for testing."""
return RequirementSpec(
user_input="分析工单健康度和趋势",
objectives=[
AnalysisObjective(
name="健康度分析",
description="评估工单处理的健康状况",
metrics=["完成率", "处理效率"],
priority=5
),
AnalysisObjective(
name="趋势分析",
description="分析工单随时间的变化趋势",
metrics=["时间序列", "增长率"],
priority=4
)
]
)
def test_fallback_planning_generates_tasks(sample_data_profile, sample_requirement):
"""Test that fallback planning generates tasks."""
plan = _fallback_analysis_planning(sample_data_profile, sample_requirement)
# Should have tasks
assert len(plan.tasks) > 0
# Should have objectives
assert len(plan.objectives) == len(sample_requirement.objectives)
# Should have estimated duration
assert plan.estimated_duration > 0
def test_fallback_planning_respects_objectives(sample_data_profile, sample_requirement):
"""Test that fallback planning creates tasks based on objectives."""
plan = _fallback_analysis_planning(sample_data_profile, sample_requirement)
# Should have tasks related to health analysis
health_tasks = [t for t in plan.tasks if '健康' in t.name or '质量' in t.name]
assert len(health_tasks) > 0
# Should have tasks related to trend analysis
trend_tasks = [t for t in plan.tasks if '趋势' in t.name or '时间' in t.name]
assert len(trend_tasks) > 0
def test_fallback_planning_with_no_matching_objectives(sample_data_profile):
"""Test fallback planning with generic objectives."""
requirement = RequirementSpec(
user_input="分析数据",
objectives=[
AnalysisObjective(
name="综合分析",
description="全面分析数据",
metrics=[],
priority=3
)
]
)
plan = _fallback_analysis_planning(sample_data_profile, requirement)
# Should still generate at least one task
assert len(plan.tasks) > 0
def test_fallback_planning_with_empty_objectives(sample_data_profile):
"""Test fallback planning with no objectives."""
requirement = RequirementSpec(
user_input="分析数据",
objectives=[]
)
plan = _fallback_analysis_planning(sample_data_profile, requirement)
# Should generate default task
assert len(plan.tasks) > 0
def test_validate_dependencies_valid():
"""Test validation with valid dependencies."""
tasks = [
AnalysisTask(
id="task_1",
name="Task 1",
description="First task",
priority=5,
dependencies=[]
),
AnalysisTask(
id="task_2",
name="Task 2",
description="Second task",
priority=4,
dependencies=["task_1"]
),
AnalysisTask(
id="task_3",
name="Task 3",
description="Third task",
priority=3,
dependencies=["task_1", "task_2"]
)
]
validation = validate_task_dependencies(tasks)
assert validation['valid']
assert validation['forms_dag']
assert not validation['has_circular_dependency']
assert len(validation['missing_dependencies']) == 0
def test_validate_dependencies_with_cycle():
"""Test validation detects circular dependencies."""
tasks = [
AnalysisTask(
id="task_1",
name="Task 1",
description="First task",
priority=5,
dependencies=["task_2"]
),
AnalysisTask(
id="task_2",
name="Task 2",
description="Second task",
priority=4,
dependencies=["task_1"]
)
]
validation = validate_task_dependencies(tasks)
assert not validation['valid']
assert validation['has_circular_dependency']
assert not validation['forms_dag']
def test_validate_dependencies_with_missing():
"""Test validation detects missing dependencies."""
tasks = [
AnalysisTask(
id="task_1",
name="Task 1",
description="First task",
priority=5,
dependencies=["task_999"] # Doesn't exist
)
]
validation = validate_task_dependencies(tasks)
assert not validation['valid']
assert len(validation['missing_dependencies']) > 0
def test_has_circular_dependency_simple_cycle():
"""Test circular dependency detection with simple cycle."""
tasks = [
AnalysisTask(
id="A",
name="Task A",
description="Task A",
priority=3,
dependencies=["B"]
),
AnalysisTask(
id="B",
name="Task B",
description="Task B",
priority=3,
dependencies=["A"]
)
]
assert _has_circular_dependency(tasks)
def test_has_circular_dependency_complex_cycle():
"""Test circular dependency detection with complex cycle."""
tasks = [
AnalysisTask(
id="A",
name="Task A",
description="Task A",
priority=3,
dependencies=["B"]
),
AnalysisTask(
id="B",
name="Task B",
description="Task B",
priority=3,
dependencies=["C"]
),
AnalysisTask(
id="C",
name="Task C",
description="Task C",
priority=3,
dependencies=["A"] # Cycle: A -> B -> C -> A
)
]
assert _has_circular_dependency(tasks)
def test_has_circular_dependency_no_cycle():
"""Test circular dependency detection with no cycle."""
tasks = [
AnalysisTask(
id="A",
name="Task A",
description="Task A",
priority=3,
dependencies=[]
),
AnalysisTask(
id="B",
name="Task B",
description="Task B",
priority=3,
dependencies=["A"]
),
AnalysisTask(
id="C",
name="Task C",
description="Task C",
priority=3,
dependencies=["A", "B"]
)
]
assert not _has_circular_dependency(tasks)
def test_task_priority_range(sample_data_profile, sample_requirement):
"""Test that all generated tasks have valid priority range."""
plan = _fallback_analysis_planning(sample_data_profile, sample_requirement)
for task in plan.tasks:
assert 1 <= task.priority <= 5, \
f"Task {task.id} has invalid priority {task.priority}"
def test_task_unique_ids(sample_data_profile, sample_requirement):
"""Test that all tasks have unique IDs."""
plan = _fallback_analysis_planning(sample_data_profile, sample_requirement)
task_ids = [task.id for task in plan.tasks]
assert len(task_ids) == len(set(task_ids)), "Task IDs should be unique"
def test_plan_has_timestamps(sample_data_profile, sample_requirement):
"""Test that plan has creation and update timestamps."""
plan = _fallback_analysis_planning(sample_data_profile, sample_requirement)
assert plan.created_at is not None
assert plan.updated_at is not None
def test_task_required_tools_is_list(sample_data_profile, sample_requirement):
"""Test that required_tools is always a list."""
plan = _fallback_analysis_planning(sample_data_profile, sample_requirement)
for task in plan.tasks:
assert isinstance(task.required_tools, list), \
f"Task {task.id} required_tools should be a list"
def test_task_dependencies_is_list(sample_data_profile, sample_requirement):
"""Test that dependencies is always a list."""
plan = _fallback_analysis_planning(sample_data_profile, sample_requirement)
for task in plan.tasks:
assert isinstance(task.dependencies, list), \
f"Task {task.id} dependencies should be a list"


@@ -1,265 +0,0 @@
"""Property-based tests for analysis planning engine."""
import pytest
from hypothesis import given, strategies as st, settings
from src.engines.analysis_planning import (
plan_analysis,
validate_task_dependencies,
_fallback_analysis_planning,
_has_circular_dependency
)
from src.models.data_profile import DataProfile, ColumnInfo
from src.models.requirement_spec import RequirementSpec, AnalysisObjective
from src.models.analysis_plan import AnalysisTask
# Strategies for generating test data
@st.composite
def column_info_strategy(draw):
"""Generate random ColumnInfo."""
name = draw(st.text(min_size=1, max_size=20, alphabet=st.characters(whitelist_categories=('L', 'N'))))
dtype = draw(st.sampled_from(['numeric', 'categorical', 'datetime', 'text']))
missing_rate = draw(st.floats(min_value=0.0, max_value=1.0))
unique_count = draw(st.integers(min_value=1, max_value=1000))
return ColumnInfo(
name=name,
dtype=dtype,
missing_rate=missing_rate,
unique_count=unique_count,
sample_values=[],
statistics={}
)
@st.composite
def data_profile_strategy(draw):
"""Generate random DataProfile."""
row_count = draw(st.integers(min_value=10, max_value=100000))
columns = draw(st.lists(column_info_strategy(), min_size=2, max_size=20))
inferred_type = draw(st.sampled_from(['ticket', 'sales', 'user', 'unknown']))
quality_score = draw(st.floats(min_value=0.0, max_value=100.0))
return DataProfile(
file_path='test.csv',
row_count=row_count,
column_count=len(columns),
columns=columns,
inferred_type=inferred_type,
key_fields={},
quality_score=quality_score,
summary=f"Test data with {len(columns)} columns"
)
@st.composite
def requirement_spec_strategy(draw):
"""Generate random RequirementSpec."""
user_input = draw(st.text(min_size=5, max_size=100))
num_objectives = draw(st.integers(min_value=1, max_value=5))
objectives = []
for i in range(num_objectives):
obj = AnalysisObjective(
name=f"Objective {i+1}",
description=draw(st.text(min_size=10, max_size=100)),
metrics=draw(st.lists(st.text(min_size=3, max_size=20), min_size=1, max_size=5)),
priority=draw(st.integers(min_value=1, max_value=5))
)
objectives.append(obj)
return RequirementSpec(
user_input=user_input,
objectives=objectives
)
# Feature: true-ai-agent, Property 6: dynamic task generation
@given(
data_profile=data_profile_strategy(),
requirement=requirement_spec_strategy()
)
@settings(max_examples=20, deadline=None)
def test_dynamic_task_generation(data_profile, requirement):
"""
Property 6: For any data profile and requirement spec, the analysis
planning engine should be able to generate a non-empty task list, with
each task containing unique ID, description, priority, and required tools.
Validates: scenario 1 acceptance criterion 2, FR-3.1
"""
# Use fallback to avoid API dependency
plan = _fallback_analysis_planning(data_profile, requirement)
# Verify: Should have tasks
assert len(plan.tasks) > 0, "Should generate at least one task"
# Verify: Each task should have required fields
task_ids = set()
for task in plan.tasks:
# Unique ID
assert task.id not in task_ids, f"Task ID {task.id} is not unique"
task_ids.add(task.id)
# Required fields
assert len(task.name) > 0, "Task name should not be empty"
assert len(task.description) > 0, "Task description should not be empty"
assert 1 <= task.priority <= 5, f"Task priority {task.priority} should be between 1 and 5"
assert isinstance(task.required_tools, list), "Required tools should be a list"
assert isinstance(task.dependencies, list), "Dependencies should be a list"
assert task.status in ['pending', 'running', 'completed', 'failed', 'skipped'], \
f"Invalid task status: {task.status}"
# Verify: Plan should have objectives
assert len(plan.objectives) > 0, "Plan should have objectives"
# Verify: Estimated duration should be non-negative
assert plan.estimated_duration >= 0, "Estimated duration should be non-negative"
# Feature: true-ai-agent, Property 7: task dependency consistency
@given(
data_profile=data_profile_strategy(),
requirement=requirement_spec_strategy()
)
@settings(max_examples=20, deadline=None)
def test_task_dependency_consistency(data_profile, requirement):
"""
Property 7: For any generated analysis plan, all task dependencies should
form a directed acyclic graph (DAG), with no circular dependencies.
Validates: FR-3.1
"""
# Use fallback to avoid API dependency
plan = _fallback_analysis_planning(data_profile, requirement)
# Verify: No circular dependencies
assert not _has_circular_dependency(plan.tasks), \
"Task dependencies should not form a cycle"
# Verify: All dependencies exist
task_ids = {task.id for task in plan.tasks}
for task in plan.tasks:
for dep_id in task.dependencies:
assert dep_id in task_ids, \
f"Task {task.id} depends on non-existent task {dep_id}"
assert dep_id != task.id, \
f"Task {task.id} should not depend on itself"
# Verify: Validation function agrees
validation = validate_task_dependencies(plan.tasks)
assert validation['valid'], "Task dependencies should be valid"
assert validation['forms_dag'], "Task dependencies should form a DAG"
assert not validation['has_circular_dependency'], "Should not have circular dependencies"
assert len(validation['missing_dependencies']) == 0, "Should not have missing dependencies"
# Feature: true-ai-agent, Property 6: dynamic task generation (priority ordering)
@given(
data_profile=data_profile_strategy(),
requirement=requirement_spec_strategy()
)
@settings(max_examples=20, deadline=None)
def test_task_priority_ordering(data_profile, requirement):
"""
Property 6 (extended): Tasks should respect objective priorities.
High-priority objectives should generate high-priority tasks.
Validates: FR-3.2
"""
# Use fallback to avoid API dependency
plan = _fallback_analysis_planning(data_profile, requirement)
# Verify: All tasks have valid priorities
for task in plan.tasks:
assert 1 <= task.priority <= 5, \
f"Task priority {task.priority} should be between 1 and 5"
# Verify: If objectives have high priority, at least some tasks should too
max_obj_priority = max(obj.priority for obj in plan.objectives)
if max_obj_priority >= 4:
# Should have at least one high-priority task
high_priority_tasks = [t for t in plan.tasks if t.priority >= 4]
# This is a soft requirement, so we just check structure
assert all(1 <= t.priority <= 5 for t in plan.tasks)
# Test circular dependency detection
@given(
num_tasks=st.integers(min_value=2, max_value=10)
)
@settings(max_examples=10, deadline=None)
def test_circular_dependency_detection(num_tasks):
"""
Test that circular dependency detection works correctly.
"""
# Create tasks with no dependencies (should be valid)
tasks = [
AnalysisTask(
id=f"task_{i}",
name=f"Task {i}",
description=f"Description {i}",
priority=3,
dependencies=[]
)
for i in range(num_tasks)
]
# Should not have circular dependencies
assert not _has_circular_dependency(tasks)
# Create a simple cycle: task_0 -> task_1 -> task_0
if num_tasks >= 2:
tasks_with_cycle = [
AnalysisTask(
id="task_0",
name="Task 0",
description="Description 0",
priority=3,
dependencies=["task_1"]
),
AnalysisTask(
id="task_1",
name="Task 1",
description="Description 1",
priority=3,
dependencies=["task_0"]
)
]
# Should detect the cycle
assert _has_circular_dependency(tasks_with_cycle)
# Test dependency validation
def test_dependency_validation_with_missing_deps():
"""Test validation detects missing dependencies."""
tasks = [
AnalysisTask(
id="task_1",
name="Task 1",
description="Description 1",
priority=3,
dependencies=["task_2", "task_999"] # task_999 doesn't exist
),
AnalysisTask(
id="task_2",
name="Task 2",
description="Description 2",
priority=3,
dependencies=[]
)
]
validation = validate_task_dependencies(tasks)
# Should not be valid
assert not validation['valid']
# Should have missing dependencies
assert len(validation['missing_dependencies']) > 0
# Should identify task_999 as missing
missing_dep_ids = [md['missing_dep'] for md in validation['missing_dependencies']]
assert 'task_999' in missing_dep_ids


@@ -1,430 +0,0 @@
"""Unit tests for the configuration management module."""
import os
import json
import pytest
from pathlib import Path
from unittest.mock import patch
from src.config import (
LLMConfig,
PerformanceConfig,
OutputConfig,
Config,
get_config,
set_config,
load_config_from_env,
load_config_from_file
)
class TestLLMConfig:
"""Tests for LLM configuration."""
def test_default_config(self):
"""Test the default configuration."""
config = LLMConfig(api_key="test_key")
assert config.provider == "openai"
assert config.api_key == "test_key"
assert config.base_url == "https://api.openai.com/v1"
assert config.model == "gpt-4"
assert config.timeout == 120
assert config.max_retries == 3
assert config.temperature == 0.7
assert config.max_tokens is None
def test_custom_config(self):
"""Test a custom configuration."""
config = LLMConfig(
provider="gemini",
api_key="gemini_key",
base_url="https://gemini.api",
model="gemini-pro",
timeout=60,
max_retries=5,
temperature=0.5,
max_tokens=1000
)
assert config.provider == "gemini"
assert config.api_key == "gemini_key"
assert config.base_url == "https://gemini.api"
assert config.model == "gemini-pro"
assert config.timeout == 60
assert config.max_retries == 5
assert config.temperature == 0.5
assert config.max_tokens == 1000
def test_empty_api_key(self):
"""Test an empty API key."""
with pytest.raises(ValueError, match="API key 不能为空"):
LLMConfig(api_key="")
def test_invalid_provider(self):
"""Test an invalid provider."""
with pytest.raises(ValueError, match="不支持的 LLM provider"):
LLMConfig(api_key="test", provider="invalid")
def test_invalid_timeout(self):
"""Test an invalid timeout."""
with pytest.raises(ValueError, match="timeout 必须大于 0"):
LLMConfig(api_key="test", timeout=0)
def test_invalid_max_retries(self):
"""Test an invalid max_retries."""
with pytest.raises(ValueError, match="max_retries 不能为负数"):
LLMConfig(api_key="test", max_retries=-1)
class TestPerformanceConfig:
"""Tests for performance configuration."""
def test_default_config(self):
"""Test the default configuration."""
config = PerformanceConfig()
assert config.agent_max_rounds == 20
assert config.agent_timeout == 300
assert config.tool_max_query_rows == 10000
assert config.tool_execution_timeout == 60
assert config.data_max_rows == 1000000
assert config.data_sample_threshold == 1000000
assert config.max_concurrent_tasks == 1
def test_custom_config(self):
"""Test a custom configuration."""
config = PerformanceConfig(
agent_max_rounds=10,
agent_timeout=600,
tool_max_query_rows=5000,
tool_execution_timeout=30,
data_max_rows=500000,
data_sample_threshold=500000,
max_concurrent_tasks=2
)
assert config.agent_max_rounds == 10
assert config.agent_timeout == 600
assert config.tool_max_query_rows == 5000
assert config.tool_execution_timeout == 30
assert config.data_max_rows == 500000
assert config.data_sample_threshold == 500000
assert config.max_concurrent_tasks == 2
def test_invalid_agent_max_rounds(self):
"""Test an invalid agent_max_rounds."""
with pytest.raises(ValueError, match="agent_max_rounds 必须大于 0"):
PerformanceConfig(agent_max_rounds=0)
def test_invalid_tool_max_query_rows(self):
"""Test an invalid tool_max_query_rows."""
with pytest.raises(ValueError, match="tool_max_query_rows 必须大于 0"):
PerformanceConfig(tool_max_query_rows=-1)
class TestOutputConfig:
"""Tests for output configuration."""
def test_default_config(self):
"""Test the default configuration."""
config = OutputConfig()
assert config.output_dir == "output"
assert config.log_dir == "output"
assert config.chart_dir == str(Path("output") / "charts")
assert config.report_filename == "analysis_report.md"
assert config.log_level == "INFO"
assert config.log_to_file is True
assert config.log_to_console is True
def test_custom_config(self):
"""Test a custom configuration."""
config = OutputConfig(
output_dir="results",
log_dir="logs",
chart_dir="charts",
report_filename="report.md",
log_level="DEBUG",
log_to_file=False,
log_to_console=True
)
assert config.output_dir == "results"
assert config.log_dir == "logs"
assert config.chart_dir == "charts"
assert config.report_filename == "report.md"
assert config.log_level == "DEBUG"
assert config.log_to_file is False
assert config.log_to_console is True
def test_invalid_log_level(self):
"""Test an invalid log_level."""
with pytest.raises(ValueError, match="不支持的 log_level"):
OutputConfig(log_level="INVALID")
def test_get_paths(self):
"""Test the path getter methods."""
config = OutputConfig(
output_dir="results",
log_dir="logs",
chart_dir="charts"
)
assert config.get_output_path() == Path("results")
assert config.get_log_path() == Path("logs")
assert config.get_chart_path() == Path("charts")
assert config.get_report_path() == Path("results/analysis_report.md")
class TestConfig:
"""Tests for the system-level Config."""
def test_default_config(self):
"""Test the default configuration."""
config = Config(
llm=LLMConfig(api_key="test_key")
)
assert config.llm.api_key == "test_key"
assert config.performance.agent_max_rounds == 20
assert config.output.output_dir == "output"
assert config.code_repo_enable_reuse is True
def test_from_env(self):
"""Test loading configuration from environment variables."""
env_vars = {
"LLM_PROVIDER": "openai",
"OPENAI_API_KEY": "env_test_key",
"OPENAI_BASE_URL": "https://test.api",
"OPENAI_MODEL": "gpt-3.5-turbo",
"AGENT_MAX_ROUNDS": "15",
"AGENT_OUTPUT_DIR": "test_output",
"TOOL_MAX_QUERY_ROWS": "5000",
"CODE_REPO_ENABLE_REUSE": "false"
}
with patch.dict(os.environ, env_vars, clear=True):
config = Config.from_env()
assert config.llm.provider == "openai"
assert config.llm.api_key == "env_test_key"
assert config.llm.base_url == "https://test.api"
assert config.llm.model == "gpt-3.5-turbo"
assert config.performance.agent_max_rounds == 15
assert config.performance.tool_max_query_rows == 5000
assert config.output.output_dir == "test_output"
assert config.code_repo_enable_reuse is False
def test_from_env_gemini(self):
"""Test loading Gemini configuration from environment variables."""
env_vars = {
"LLM_PROVIDER": "gemini",
"GEMINI_API_KEY": "gemini_key",
"GEMINI_BASE_URL": "https://gemini.api",
"GEMINI_MODEL": "gemini-pro"
}
with patch.dict(os.environ, env_vars, clear=True):
config = Config.from_env()
assert config.llm.provider == "gemini"
assert config.llm.api_key == "gemini_key"
assert config.llm.base_url == "https://gemini.api"
assert config.llm.model == "gemini-pro"
def test_from_dict(self):
"""Test loading configuration from a dict."""
config_dict = {
"llm": {
"provider": "openai",
"api_key": "dict_test_key",
"base_url": "https://dict.api",
"model": "gpt-4",
"timeout": 90,
"max_retries": 2,
"temperature": 0.5,
"max_tokens": 2000
},
"performance": {
"agent_max_rounds": 25,
"tool_max_query_rows": 8000
},
"output": {
"output_dir": "dict_output",
"log_level": "DEBUG"
},
"code_repo_enable_reuse": False
}
config = Config.from_dict(config_dict)
assert config.llm.api_key == "dict_test_key"
assert config.llm.base_url == "https://dict.api"
assert config.llm.timeout == 90
assert config.llm.max_retries == 2
assert config.llm.temperature == 0.5
assert config.llm.max_tokens == 2000
assert config.performance.agent_max_rounds == 25
assert config.performance.tool_max_query_rows == 8000
assert config.output.output_dir == "dict_output"
assert config.output.log_level == "DEBUG"
assert config.code_repo_enable_reuse is False
def test_from_file(self, tmp_path):
"""Test loading configuration from a file."""
config_file = tmp_path / "test_config.json"
config_dict = {
"llm": {
"provider": "openai",
"api_key": "file_test_key",
"model": "gpt-4"
},
"performance": {
"agent_max_rounds": 30
}
}
with open(config_file, 'w') as f:
json.dump(config_dict, f)
config = Config.from_file(str(config_file))
assert config.llm.api_key == "file_test_key"
assert config.llm.model == "gpt-4"
assert config.performance.agent_max_rounds == 30
def test_from_file_not_found(self):
"""Test loading a nonexistent configuration file."""
with pytest.raises(FileNotFoundError):
Config.from_file("nonexistent.json")
def test_to_dict(self):
"""Test conversion to a dict."""
config = Config(
llm=LLMConfig(
api_key="test_key",
model="gpt-4"
),
performance=PerformanceConfig(
agent_max_rounds=15
),
output=OutputConfig(
output_dir="test_output"
)
)
config_dict = config.to_dict()
assert config_dict["llm"]["api_key"] == "***" # the API key should be masked
assert config_dict["llm"]["model"] == "gpt-4"
assert config_dict["performance"]["agent_max_rounds"] == 15
assert config_dict["output"]["output_dir"] == "test_output"
def test_save_to_file(self, tmp_path):
"""Test saving configuration to a file."""
config_file = tmp_path / "saved_config.json"
config = Config(
llm=LLMConfig(api_key="test_key"),
performance=PerformanceConfig(agent_max_rounds=15)
)
config.save_to_file(str(config_file))
assert config_file.exists()
with open(config_file, 'r') as f:
saved_dict = json.load(f)
assert saved_dict["llm"]["api_key"] == "***"
assert saved_dict["performance"]["agent_max_rounds"] == 15
def test_validate_success(self):
"""Test successful configuration validation."""
config = Config(
llm=LLMConfig(api_key="test_key")
)
assert config.validate() is True
def test_validate_missing_api_key(self):
"""Test configuration validation failure (missing API key)."""
config = Config(
llm=LLMConfig(api_key="test_key")
)
config.llm.api_key = "" # clear it manually
assert config.validate() is False
class TestGlobalConfig:
"""Tests for global configuration management."""
def test_get_config(self):
"""Test getting the global configuration."""
# Reset the global configuration
set_config(None)
# Mock environment variables
env_vars = {
"OPENAI_API_KEY": "global_test_key"
}
with patch.dict(os.environ, env_vars, clear=True):
config = get_config()
assert config is not None
assert config.llm.api_key == "global_test_key"
def test_set_config(self):
"""Test setting the global configuration."""
custom_config = Config(
llm=LLMConfig(api_key="custom_key")
)
set_config(custom_config)
config = get_config()
assert config.llm.api_key == "custom_key"
def test_load_config_from_env(self):
"""Test loading the global configuration from environment variables."""
env_vars = {
"OPENAI_API_KEY": "env_global_key",
"AGENT_MAX_ROUNDS": "25"
}
with patch.dict(os.environ, env_vars, clear=True):
config = load_config_from_env()
assert config.llm.api_key == "env_global_key"
assert config.performance.agent_max_rounds == 25
# Verify the global configuration was updated
global_config = get_config()
assert global_config.llm.api_key == "env_global_key"
def test_load_config_from_file(self, tmp_path):
"""Test loading the global configuration from a file."""
config_file = tmp_path / "global_config.json"
config_dict = {
"llm": {
"provider": "openai",
"api_key": "file_global_key",
"model": "gpt-4"
}
}
with open(config_file, 'w') as f:
json.dump(config_dict, f)
config = load_config_from_file(str(config_file))
assert config.llm.api_key == "file_global_key"
# Verify the global configuration was updated
global_config = get_config()
assert global_config.llm.api_key == "file_global_key"


@@ -1,268 +0,0 @@
"""Unit tests for the data access layer."""
import pytest
import pandas as pd
import tempfile
import os
from pathlib import Path
from src.data_access import DataAccessLayer, DataLoadError
class TestDataAccessLayer:
"""Unit tests for the data access layer."""
def test_load_utf8_csv(self):
"""Test loading a UTF-8 encoded CSV file."""
# Create a temporary CSV file
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
f.write('id,name,value\n')
f.write('1,测试,100\n')
f.write('2,数据,200\n')
temp_file = f.name
try:
# Load the data
dal = DataAccessLayer.load_from_file(temp_file)
assert dal.shape == (2, 3)
assert 'id' in dal.columns
assert 'name' in dal.columns
assert 'value' in dal.columns
finally:
os.unlink(temp_file)
def test_load_gbk_csv(self):
"""Test loading a GBK encoded CSV file."""
# Create a temporary GBK-encoded CSV file
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='gbk') as f:
f.write('编号,名称,数值\n')
f.write('1,测试,100\n')
f.write('2,数据,200\n')
temp_file = f.name
try:
# Load the data
dal = DataAccessLayer.load_from_file(temp_file)
assert dal.shape == (2, 3)
assert len(dal.columns) == 3
finally:
os.unlink(temp_file)
def test_load_empty_file(self):
"""Test loading an empty file."""
# Create an empty CSV file
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
f.write('id,name\n') # header only, no data rows
temp_file = f.name
try:
# Should raise DataLoadError
with pytest.raises(DataLoadError, match="为空"):
DataAccessLayer.load_from_file(temp_file)
finally:
os.unlink(temp_file)
def test_load_invalid_file(self):
"""Test loading a nonexistent file."""
with pytest.raises(DataLoadError):
DataAccessLayer.load_from_file('nonexistent_file.csv')
def test_get_profile_basic(self):
"""Test generating a basic data profile."""
# Create test data
df = pd.DataFrame({
'id': [1, 2, 3, 4, 5],
'name': ['A', 'B', 'C', 'D', 'E'],
'value': [10, 20, 30, 40, 50],
'status': ['open', 'closed', 'open', 'closed', 'open']
})
dal = DataAccessLayer(df, file_path='test.csv')
profile = dal.get_profile()
# Verify basic info
assert profile.file_path == 'test.csv'
assert profile.row_count == 5
assert profile.column_count == 4
assert len(profile.columns) == 4
# Verify column info
col_names = [col.name for col in profile.columns]
assert 'id' in col_names
assert 'name' in col_names
assert 'value' in col_names
assert 'status' in col_names
def test_get_profile_with_missing_values(self):
"""Test profiling data that contains missing values."""
df = pd.DataFrame({
'id': [1, 2, 3, 4, 5],
'value': [10, None, 30, None, 50]
})
dal = DataAccessLayer(df)
profile = dal.get_profile()
# Find the value column
value_col = next(col for col in profile.columns if col.name == 'value')
# Verify the missing rate
assert value_col.missing_rate == 0.4 # 2/5 = 0.4
def test_column_type_inference_numeric(self):
"""Test numeric type inference."""
df = pd.DataFrame({
'int_col': [1, 2, 3, 4, 5],
'float_col': [1.1, 2.2, 3.3, 4.4, 5.5]
})
dal = DataAccessLayer(df)
profile = dal.get_profile()
int_col = next(col for col in profile.columns if col.name == 'int_col')
float_col = next(col for col in profile.columns if col.name == 'float_col')
assert int_col.dtype == 'numeric'
assert float_col.dtype == 'numeric'
# Verify statistics
assert 'mean' in int_col.statistics
assert 'std' in int_col.statistics
assert 'min' in int_col.statistics
assert 'max' in int_col.statistics
def test_column_type_inference_categorical(self):
"""Test categorical type inference."""
df = pd.DataFrame({
'status': ['open', 'closed', 'open', 'closed', 'open'] * 20
})
dal = DataAccessLayer(df)
profile = dal.get_profile()
status_col = profile.columns[0]
assert status_col.dtype == 'categorical'
# Verify statistics
assert 'top_values' in status_col.statistics
assert 'num_categories' in status_col.statistics
def test_column_type_inference_datetime(self):
"""Test datetime type inference."""
df = pd.DataFrame({
'date': pd.date_range('2020-01-01', periods=10)
})
dal = DataAccessLayer(df)
profile = dal.get_profile()
date_col = profile.columns[0]
assert date_col.dtype == 'datetime'
def test_sample_values_limit(self):
"""Test the sample value count limit."""
df = pd.DataFrame({
'id': list(range(100))
})
dal = DataAccessLayer(df)
profile = dal.get_profile()
id_col = profile.columns[0]
# There should be at most 5 sample values
assert len(id_col.sample_values) <= 5
def test_sanitize_result_dataframe(self):
"""Test result sanitization for a DataFrame."""
df = pd.DataFrame({
'id': list(range(200)),
'value': list(range(200))
})
dal = DataAccessLayer(df)
# Simulate a tool returning a large amount of data
result = {'data': df}
sanitized = dal._sanitize_result(result)
# Verify: the returned data should be truncated to 100 rows
assert len(sanitized['data']) <= 100
def test_sanitize_result_series(self):
"""Test result sanitization for a Series."""
df = pd.DataFrame({
'id': list(range(200))
})
dal = DataAccessLayer(df)
# Simulate a tool returning a Series
result = {'data': df['id']}
sanitized = dal._sanitize_result(result)
# Verify: the returned data should be truncated
assert len(sanitized['data']) <= 100
def test_large_dataset_sampling(self):
"""Test sampling of large datasets."""
# Create a temporary file standing in for a dataset of over one million rows
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
f.write('id,value\n')
# Write only a small amount of data for the test (a real large dataset would be slow)
for i in range(1000):
f.write(f'{i},{i*10}\n')
temp_file = f.name
try:
dal = DataAccessLayer.load_from_file(temp_file)
# Verify the data was loaded
assert dal.shape[0] == 1000
finally:
os.unlink(temp_file)
class TestDataAccessLayerIntegration:
"""Integration tests for the data access layer."""
def test_end_to_end_workflow(self):
"""Test the end-to-end workflow."""
# Create test data
df = pd.DataFrame({
'id': [1, 2, 3, 4, 5],
'status': ['open', 'closed', 'open', 'closed', 'pending'],
'value': [100, 200, 150, 300, 250],
'created_at': pd.date_range('2020-01-01', periods=5)
})
# Save to a temporary file
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
df.to_csv(f.name, index=False)
temp_file = f.name
try:
# 1. Load the data
dal = DataAccessLayer.load_from_file(temp_file)
# 2. Generate the data profile
profile = dal.get_profile()
# 3. Verify the data profile
assert profile.row_count == 5
assert profile.column_count == 4
# 4. Verify column type inference
col_types = {col.name: col.dtype for col in profile.columns}
assert col_types['id'] == 'numeric'
assert col_types['status'] == 'categorical'
assert col_types['value'] == 'numeric'
assert col_types['created_at'] == 'datetime'
# 5. Verify statistics
value_col = next(col for col in profile.columns if col.name == 'value')
assert 'mean' in value_col.statistics
assert value_col.statistics['mean'] == 200.0
finally:
os.unlink(temp_file)


@@ -1,156 +0,0 @@
"""Property-based tests for the data access layer."""
import pytest
import pandas as pd
import numpy as np
from hypothesis import given, strategies as st, settings, HealthCheck
from typing import Dict, Any
from src.data_access import DataAccessLayer
# Strategy for generating random DataFrames
@st.composite
def dataframe_strategy(draw):
"""Generate a random DataFrame for testing."""
n_rows = draw(st.integers(min_value=10, max_value=1000))
n_cols = draw(st.integers(min_value=2, max_value=20))
data = {}
for i in range(n_cols):
col_type = draw(st.sampled_from(['int', 'float', 'str', 'datetime']))
col_name = f'col_{i}'
if col_type == 'int':
data[col_name] = draw(st.lists(
st.integers(min_value=-1000, max_value=1000),
min_size=n_rows,
max_size=n_rows
))
elif col_type == 'float':
data[col_name] = draw(st.lists(
st.floats(min_value=-1000.0, max_value=1000.0, allow_nan=False, allow_infinity=False),
min_size=n_rows,
max_size=n_rows
))
elif col_type == 'str':
data[col_name] = draw(st.lists(
st.text(min_size=1, max_size=20, alphabet=st.characters(blacklist_categories=('Cs',))),
min_size=n_rows,
max_size=n_rows
))
else: # datetime
# Generate a range of dates
dates = pd.date_range('2020-01-01', periods=n_rows, freq='D')
data[col_name] = dates.tolist()
return pd.DataFrame(data)
class TestDataAccessProperties:
"""Property-based tests for the data access layer."""
# Feature: true-ai-agent, Property 18: data access restriction
@given(df=dataframe_strategy())
@settings(max_examples=20, deadline=None, suppress_health_check=[HealthCheck.data_too_large])
def test_property_18_data_access_restriction(self, df):
"""
Property 18: data access restriction.
Validates requirement constraint 5.3.
For any data, the data profile should contain only metadata and
statistical summaries, never the complete raw row-level data.
"""
# Create the data access layer
dal = DataAccessLayer(df, file_path="test.csv")
# Get the data profile
profile = dal.get_profile()
# Verify: the profile should not contain the raw data
# 1. Row and column counts are metadata
assert profile.row_count == len(df)
assert profile.column_count == len(df.columns)
# 2. Check column info
assert len(profile.columns) == len(df.columns)
for col_info in profile.columns:
# 3. Sample values should be limited (at most 5)
assert len(col_info.sample_values) <= 5
# 4. Statistics should be aggregates, not raw data
if col_info.dtype == 'numeric':
# Statistics should be single values, not arrays
if col_info.statistics:
for stat_key, stat_value in col_info.statistics.items():
assert not isinstance(stat_value, (list, np.ndarray, pd.Series))
# 应该是标量值或 None
assert stat_value is None or isinstance(stat_value, (int, float))
# 5. 缺失率应该是聚合指标0-1之间的浮点数
assert 0.0 <= col_info.missing_rate <= 1.0
# 6. 唯一值数量应该是聚合指标
assert isinstance(col_info.unique_count, int)
assert col_info.unique_count >= 0
# 7. 验证数据画像的 JSON 序列化不包含大量原始数据
profile_json = profile.to_json()
# JSON 大小应该远小于原始数据
# 原始数据至少有 n_rows * n_cols 个值
# 数据画像应该只有元数据和少量示例
original_data_size = len(df) * len(df.columns)
# 数据画像的大小应该远小于原始数据至少小于10%
assert len(profile_json) < original_data_size * 100 # 粗略估计
@given(df=dataframe_strategy())
@settings(max_examples=10, deadline=None)
def test_data_profile_completeness(self, df):
"""
Test the completeness of the data profile.
The profile must include every required metadata field.
"""
dal = DataAccessLayer(df, file_path="test.csv")
profile = dal.get_profile()
# Required fields must be present
assert profile.file_path == "test.csv"
assert profile.row_count > 0
assert profile.column_count > 0
assert len(profile.columns) > 0
assert profile.inferred_type is not None
# Verify the completeness of each column's info
for col_info in profile.columns:
assert col_info.name is not None
assert col_info.dtype in ['numeric', 'categorical', 'datetime', 'text']
assert 0.0 <= col_info.missing_rate <= 1.0
assert col_info.unique_count >= 0
assert isinstance(col_info.sample_values, list)
assert isinstance(col_info.statistics, dict)
@given(df=dataframe_strategy())
@settings(max_examples=10, deadline=None)
def test_column_type_inference(self, df):
"""
Test the correctness of column type inference.
The inferred type must be consistent with the actual dtype.
"""
dal = DataAccessLayer(df, file_path="test.csv")
profile = dal.get_profile()
for col_info in profile.columns:
col_name = col_info.name
actual_dtype = df[col_name].dtype
# The inferred type must be plausible for the actual dtype
if pd.api.types.is_numeric_dtype(actual_dtype):
assert col_info.dtype in ['numeric', 'categorical']
elif pd.api.types.is_datetime64_any_dtype(actual_dtype):
assert col_info.dtype == 'datetime'
elif pd.api.types.is_object_dtype(actual_dtype):
assert col_info.dtype in ['categorical', 'text', 'datetime']


@@ -1,311 +0,0 @@
"""数据理解引擎的单元测试。"""
import pytest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from src.engines.data_understanding import (
generate_basic_stats,
understand_data,
_infer_column_type,
_infer_data_type,
_identify_key_fields,
_evaluate_data_quality,
_get_sample_values,
_generate_column_statistics
)
from src.models import DataProfile, ColumnInfo
class TestGenerateBasicStats:
"""测试基础统计生成。"""
def test_basic_functionality(self):
"""测试基本功能。"""
df = pd.DataFrame({
'id': [1, 2, 3, 4, 5],
'name': ['A', 'B', 'C', 'D', 'E'],
'value': [10.5, 20.3, 30.1, 40.8, 50.2]
})
stats = generate_basic_stats(df, 'test.csv')
assert stats['file_path'] == 'test.csv'
assert stats['row_count'] == 5
assert stats['column_count'] == 3
assert len(stats['columns']) == 3
def test_empty_dataframe(self):
"""测试空 DataFrame。"""
df = pd.DataFrame()
stats = generate_basic_stats(df, 'empty.csv')
assert stats['row_count'] == 0
assert stats['column_count'] == 0
assert len(stats['columns']) == 0
class TestInferColumnType:
"""测试列类型推断。"""
def test_numeric_column(self):
"""测试数值列。"""
col = pd.Series([1, 2, 3, 4, 5])
dtype = _infer_column_type(col)
assert dtype == 'numeric'
def test_categorical_column(self):
"""测试分类列。"""
col = pd.Series(['A', 'B', 'A', 'C', 'B', 'A', 'A', 'B', 'C', 'A']) # 10个值3个唯一值比例30%
dtype = _infer_column_type(col)
assert dtype == 'categorical'
def test_datetime_column(self):
"""测试日期时间列。"""
col = pd.Series(pd.date_range('2020-01-01', periods=5))
dtype = _infer_column_type(col)
assert dtype == 'datetime'
def test_text_column(self):
"""测试文本列(唯一值多)。"""
col = pd.Series([f'text_{i}' for i in range(100)])
dtype = _infer_column_type(col)
assert dtype == 'text'
class TestInferDataType:
"""测试数据类型推断。"""
def test_ticket_data(self):
"""测试工单数据识别。"""
columns = [
ColumnInfo(name='ticket_id', dtype='text', missing_rate=0.0, unique_count=100),
ColumnInfo(name='status', dtype='categorical', missing_rate=0.0, unique_count=5),
ColumnInfo(name='created_at', dtype='datetime', missing_rate=0.0, unique_count=100),
]
data_type = _infer_data_type(columns)
assert data_type == 'ticket'
def test_sales_data(self):
"""测试销售数据识别。"""
columns = [
ColumnInfo(name='order_id', dtype='text', missing_rate=0.0, unique_count=100),
ColumnInfo(name='product', dtype='categorical', missing_rate=0.0, unique_count=10),
ColumnInfo(name='amount', dtype='numeric', missing_rate=0.0, unique_count=50),
]
data_type = _infer_data_type(columns)
assert data_type == 'sales'
def test_user_data(self):
"""测试用户数据识别。"""
columns = [
ColumnInfo(name='user_id', dtype='text', missing_rate=0.0, unique_count=100),
ColumnInfo(name='name', dtype='text', missing_rate=0.0, unique_count=100),
ColumnInfo(name='email', dtype='text', missing_rate=0.0, unique_count=100),
]
data_type = _infer_data_type(columns)
assert data_type == 'user'
def test_unknown_data(self):
"""测试未知数据类型。"""
columns = [
ColumnInfo(name='col1', dtype='numeric', missing_rate=0.0, unique_count=100),
ColumnInfo(name='col2', dtype='numeric', missing_rate=0.0, unique_count=100),
]
data_type = _infer_data_type(columns)
assert data_type == 'unknown'
class TestIdentifyKeyFields:
"""测试关键字段识别。"""
def test_time_fields(self):
"""测试时间字段识别。"""
columns = [
ColumnInfo(name='created_at', dtype='datetime', missing_rate=0.0, unique_count=100),
ColumnInfo(name='closed_at', dtype='datetime', missing_rate=0.0, unique_count=100),
]
key_fields = _identify_key_fields(columns)
assert 'created_at' in key_fields
assert 'closed_at' in key_fields
assert '创建时间' in key_fields['created_at']
assert '完成时间' in key_fields['closed_at']
def test_status_field(self):
"""测试状态字段识别。"""
columns = [
ColumnInfo(name='status', dtype='categorical', missing_rate=0.0, unique_count=5),
]
key_fields = _identify_key_fields(columns)
assert 'status' in key_fields
assert '状态' in key_fields['status']
def test_id_field(self):
"""测试ID字段识别。"""
columns = [
ColumnInfo(name='ticket_id', dtype='text', missing_rate=0.0, unique_count=100),
]
key_fields = _identify_key_fields(columns)
assert 'ticket_id' in key_fields
assert '标识符' in key_fields['ticket_id']
class TestEvaluateDataQuality:
"""测试数据质量评估。"""
def test_high_quality_data(self):
"""测试高质量数据。"""
columns = [
ColumnInfo(name='col1', dtype='numeric', missing_rate=0.0, unique_count=100),
ColumnInfo(name='col2', dtype='categorical', missing_rate=0.0, unique_count=5),
]
quality_score = _evaluate_data_quality(columns, row_count=100)
assert quality_score >= 80
def test_low_quality_data(self):
"""测试低质量数据(高缺失率)。"""
columns = [
ColumnInfo(name='col1', dtype='numeric', missing_rate=0.8, unique_count=20),
ColumnInfo(name='col2', dtype='categorical', missing_rate=0.9, unique_count=2),
]
quality_score = _evaluate_data_quality(columns, row_count=100)
assert quality_score < 50
def test_empty_data(self):
"""测试空数据。"""
columns = []
quality_score = _evaluate_data_quality(columns, row_count=0)
assert quality_score == 0.0
class TestGetSampleValues:
"""测试示例值获取。"""
def test_basic_functionality(self):
"""测试基本功能。"""
col = pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
samples = _get_sample_values(col, max_samples=5)
assert len(samples) <= 5
assert all(isinstance(s, (int, float)) for s in samples)
def test_with_null_values(self):
"""测试包含空值的情况。"""
col = pd.Series([1, 2, None, 4, None, 6])
samples = _get_sample_values(col, max_samples=5)
assert len(samples) <= 4  # null values are excluded
def test_datetime_values(self):
"""测试日期时间值。"""
col = pd.Series(pd.date_range('2020-01-01', periods=5))
samples = _get_sample_values(col, max_samples=3)
assert len(samples) <= 3
assert all(isinstance(s, str) for s in samples)
class TestGenerateColumnStatistics:
"""测试列统计信息生成。"""
def test_numeric_statistics(self):
"""测试数值列统计。"""
col = pd.Series([1, 2, 3, 4, 5])
stats = _generate_column_statistics(col, 'numeric')
assert 'mean' in stats
assert 'median' in stats
assert 'std' in stats
assert 'min' in stats
assert 'max' in stats
assert stats['mean'] == 3.0
assert stats['min'] == 1.0
assert stats['max'] == 5.0
def test_categorical_statistics(self):
"""测试分类列统计。"""
col = pd.Series(['A', 'B', 'A', 'C', 'A'])
stats = _generate_column_statistics(col, 'categorical')
assert 'most_common' in stats
assert 'most_common_count' in stats
assert stats['most_common'] == 'A'
assert stats['most_common_count'] == 3
def test_datetime_statistics(self):
"""测试日期时间列统计。"""
col = pd.Series(pd.date_range('2020-01-01', periods=10))
stats = _generate_column_statistics(col, 'datetime')
assert 'min_date' in stats
assert 'max_date' in stats
assert 'date_range_days' in stats
def test_text_statistics(self):
"""测试文本列统计。"""
col = pd.Series(['hello', 'world', 'test'])
stats = _generate_column_statistics(col, 'text')
assert 'avg_length' in stats
assert 'max_length' in stats
class TestUnderstandData:
"""测试完整的数据理解流程。"""
def test_basic_functionality(self):
"""测试基本功能。"""
df = pd.DataFrame({
'ticket_id': [1, 2, 3, 4, 5],
'status': ['open', 'closed', 'open', 'pending', 'closed'],
'created_at': pd.date_range('2020-01-01', periods=5),
'amount': [100, 200, 150, 300, 250]
})
profile = understand_data('test.csv', data=df)
assert isinstance(profile, DataProfile)
assert profile.row_count == 5
assert profile.column_count == 4
assert len(profile.columns) == 4
assert profile.inferred_type in ['ticket', 'sales', 'user', 'unknown']
assert 0 <= profile.quality_score <= 100
assert len(profile.summary) > 0
def test_with_missing_values(self):
"""测试包含缺失值的数据。"""
df = pd.DataFrame({
'col1': [1, 2, None, 4, 5],
'col2': ['A', None, 'C', 'D', None]
})
profile = understand_data('test.csv', data=df)
assert profile.row_count == 5
# The quality score should drop because of the missing values
assert profile.quality_score < 100


@@ -1,273 +0,0 @@
"""数据理解引擎的基于属性的测试。"""
import pytest
import pandas as pd
import numpy as np
from hypothesis import given, strategies as st, settings, assume
from typing import Dict, Any
from src.engines.data_understanding import (
generate_basic_stats,
understand_data,
_infer_column_type,
_infer_data_type,
_identify_key_fields,
_evaluate_data_quality
)
from src.models import DataProfile, ColumnInfo
# Hypothesis strategies for generating test data
@st.composite
def dataframe_strategy(draw, min_rows=10, max_rows=100, min_cols=2, max_cols=10):
"""生成随机的 DataFrame 实例。"""
n_rows = draw(st.integers(min_value=min_rows, max_value=max_rows))
n_cols = draw(st.integers(min_value=min_cols, max_value=max_cols))
data = {}
for i in range(n_cols):
col_type = draw(st.sampled_from(['int', 'float', 'str', 'datetime']))
col_name = f'col_{i}'
if col_type == 'int':
data[col_name] = draw(st.lists(
st.integers(min_value=-1000, max_value=1000),
min_size=n_rows,
max_size=n_rows
))
elif col_type == 'float':
data[col_name] = draw(st.lists(
st.floats(min_value=-1000.0, max_value=1000.0, allow_nan=False, allow_infinity=False),
min_size=n_rows,
max_size=n_rows
))
elif col_type == 'datetime':
start_date = pd.Timestamp('2020-01-01')
data[col_name] = pd.date_range(start=start_date, periods=n_rows, freq='D')
else: # str
data[col_name] = draw(st.lists(
st.text(min_size=1, max_size=10, alphabet=st.characters(whitelist_categories=('Lu', 'Ll'))),
min_size=n_rows,
max_size=n_rows
))
return pd.DataFrame(data)
# Feature: true-ai-agent, Property 1: data type inference
@given(df=dataframe_strategy(min_rows=10, max_rows=100))
@settings(max_examples=20, deadline=None)
def test_data_type_inference(df):
"""
属性 1对于任何有效的 CSV 文件,数据理解引擎应该能够推断出数据的业务类型
(如工单、销售、用户等),并且推断结果应该基于列名、数据类型和值分布的分析。
验证需求场景1验收.1
"""
# Run data understanding
profile = understand_data(file_path='test.csv', data=df)
# There must be an inferred type
assert profile.inferred_type is not None, "inferred data type must not be None"
assert profile.inferred_type in ['ticket', 'sales', 'user', 'unknown'], \
f"inferred data type must be one of the predefined types, got: {profile.inferred_type}"
# The inference must be grounded in the data's features:
# at minimum, some key fields or a summary should be produced
assert len(profile.summary) > 0, "a data summary should be generated"
# Feature: true-ai-agent, Property 2: data profile completeness
@given(df=dataframe_strategy(min_rows=5, max_rows=50))
@settings(max_examples=20, deadline=None)
def test_data_profile_completeness(df):
"""
属性 2对于任何有效的 CSV 文件,生成的数据画像应该包含所有必需字段
(行数、列数、列信息、推断类型、关键字段、质量分数),并且列信息应该
包含每列的名称、类型、缺失率和统计信息。
验证需求FR-1.2, FR-1.3, FR-1.4
"""
# Run data understanding
profile = understand_data(file_path='test.csv', data=df)
# The profile must contain every required field
assert hasattr(profile, 'file_path'), "profile is missing the file_path field"
assert hasattr(profile, 'row_count'), "profile is missing the row_count field"
assert hasattr(profile, 'column_count'), "profile is missing the column_count field"
assert hasattr(profile, 'columns'), "profile is missing the columns field"
assert hasattr(profile, 'inferred_type'), "profile is missing the inferred_type field"
assert hasattr(profile, 'key_fields'), "profile is missing the key_fields field"
assert hasattr(profile, 'quality_score'), "profile is missing the quality_score field"
assert hasattr(profile, 'summary'), "profile is missing the summary field"
# Row and column counts must be correct
assert profile.row_count == len(df), f"row count mismatch: expected {len(df)}, got {profile.row_count}"
assert profile.column_count == len(df.columns), \
f"column count mismatch: expected {len(df.columns)}, got {profile.column_count}"
# Column info must be complete
assert len(profile.columns) == len(df.columns), \
f"column info count mismatch: expected {len(df.columns)}, got {len(profile.columns)}"
for col_info in profile.columns:
# Each column must have a name, type, and missing rate
assert hasattr(col_info, 'name'), "column info is missing the name field"
assert hasattr(col_info, 'dtype'), "column info is missing the dtype field"
assert hasattr(col_info, 'missing_rate'), "column info is missing the missing_rate field"
assert hasattr(col_info, 'unique_count'), "column info is missing the unique_count field"
assert hasattr(col_info, 'statistics'), "column info is missing the statistics field"
# The dtype must be one of the predefined types
assert col_info.dtype in ['numeric', 'categorical', 'datetime', 'text'], \
f"column {col_info.name} dtype must be one of the predefined types, got: {col_info.dtype}"
# The missing rate must lie in [0, 1]
assert 0.0 <= col_info.missing_rate <= 1.0, \
f"column {col_info.name} missing rate must be within [0, 1], got: {col_info.missing_rate}"
# The unique-value count must be sensible
assert col_info.unique_count >= 0, \
f"column {col_info.name} unique count must be non-negative, got: {col_info.unique_count}"
assert col_info.unique_count <= len(df), \
f"column {col_info.name} unique count must not exceed the row count"
# The quality score must lie in [0, 100]
assert 0.0 <= profile.quality_score <= 100.0, \
f"quality score must be within [0, 100], got: {profile.quality_score}"
# Additional test: correctness of column type inference
@given(
numeric_data=st.lists(st.floats(min_value=-1000, max_value=1000, allow_nan=False, allow_infinity=False),
min_size=10, max_size=100),
categorical_data=st.lists(st.sampled_from(['A', 'B', 'C', 'D']), min_size=10, max_size=100)
)
@settings(max_examples=10)
def test_column_type_inference(numeric_data, categorical_data):
"""测试列类型推断的正确性。"""
# 测试数值列
numeric_series = pd.Series(numeric_data)
numeric_type = _infer_column_type(numeric_series)
assert numeric_type == 'numeric', f"数值列应该被识别为 'numeric',但得到:{numeric_type}"
# 测试分类列
categorical_series = pd.Series(categorical_data)
categorical_type = _infer_column_type(categorical_series)
assert categorical_type == 'categorical', \
f"分类列应该被识别为 'categorical',但得到:{categorical_type}"
# Additional test: sanity of data quality evaluation
@given(
missing_rate=st.floats(min_value=0.0, max_value=1.0),
n_cols=st.integers(min_value=1, max_value=10)
)
@settings(max_examples=10)
def test_data_quality_evaluation(missing_rate, n_cols):
"""测试数据质量评估的合理性。"""
# 创建具有指定缺失率的列信息
columns = []
for i in range(n_cols):
col_info = ColumnInfo(
name=f'col_{i}',
dtype='numeric',
missing_rate=missing_rate,
unique_count=100,
sample_values=[1, 2, 3],
statistics={}
)
columns.append(col_info)
# 评估数据质量
quality_score = _evaluate_data_quality(columns, row_count=100)
# 验证:质量分数应该在 0-100 之间
assert 0.0 <= quality_score <= 100.0, \
f"质量分数应该在 0-100 之间,但得到:{quality_score}"
# 验证:缺失率越高,质量分数应该越低
if missing_rate > 0.5:
assert quality_score < 70, \
f"高缺失率({missing_rate})应该导致较低的质量分数,但得到:{quality_score}"
# Additional test: completeness of basic statistics generation
@given(df=dataframe_strategy(min_rows=5, max_rows=50))
@settings(max_examples=10, deadline=None)
def test_basic_stats_generation(df):
"""测试基础统计生成的完整性。"""
# 生成基础统计
stats = generate_basic_stats(df, file_path='test.csv')
# 验证:应该包含必需字段
assert 'file_path' in stats, "基础统计缺少 file_path 字段"
assert 'row_count' in stats, "基础统计缺少 row_count 字段"
assert 'column_count' in stats, "基础统计缺少 column_count 字段"
assert 'columns' in stats, "基础统计缺少 columns 字段"
# 验证:统计信息应该准确
assert stats['row_count'] == len(df), "行数统计不准确"
assert stats['column_count'] == len(df.columns), "列数统计不准确"
assert len(stats['columns']) == len(df.columns), "列信息数量不匹配"
# Additional test: key field identification
def test_key_field_identification():
"""测试关键字段识别功能。"""
# 创建包含典型字段名的列信息
columns = [
ColumnInfo(name='created_at', dtype='datetime', missing_rate=0.0, unique_count=100),
ColumnInfo(name='status', dtype='categorical', missing_rate=0.0, unique_count=5),
ColumnInfo(name='ticket_id', dtype='text', missing_rate=0.0, unique_count=100),
ColumnInfo(name='amount', dtype='numeric', missing_rate=0.0, unique_count=50),
]
# 识别关键字段
key_fields = _identify_key_fields(columns)
# 验证:应该识别出时间字段
assert 'created_at' in key_fields, "应该识别出 created_at 为关键字段"
# 验证:应该识别出状态字段
assert 'status' in key_fields, "应该识别出 status 为关键字段"
# 验证应该识别出ID字段
assert 'ticket_id' in key_fields, "应该识别出 ticket_id 为关键字段"
# 验证:应该识别出金额字段
assert 'amount' in key_fields, "应该识别出 amount 为关键字段"
# Additional test: data type inference
def test_data_type_inference_with_keywords():
"""测试基于关键词的数据类型推断。"""
# 工单数据
ticket_columns = [
ColumnInfo(name='ticket_id', dtype='text', missing_rate=0.0, unique_count=100),
ColumnInfo(name='status', dtype='categorical', missing_rate=0.0, unique_count=5),
ColumnInfo(name='created_at', dtype='datetime', missing_rate=0.0, unique_count=100),
]
ticket_type = _infer_data_type(ticket_columns)
assert ticket_type == 'ticket', f"应该识别为工单数据,但得到:{ticket_type}"
# 销售数据
sales_columns = [
ColumnInfo(name='order_id', dtype='text', missing_rate=0.0, unique_count=100),
ColumnInfo(name='product', dtype='categorical', missing_rate=0.0, unique_count=10),
ColumnInfo(name='amount', dtype='numeric', missing_rate=0.0, unique_count=50),
ColumnInfo(name='sales_date', dtype='datetime', missing_rate=0.0, unique_count=100),
]
sales_type = _infer_data_type(sales_columns)
assert sales_type == 'sales', f"应该识别为销售数据,但得到:{sales_type}"
# 用户数据
user_columns = [
ColumnInfo(name='user_id', dtype='text', missing_rate=0.0, unique_count=100),
ColumnInfo(name='name', dtype='text', missing_rate=0.0, unique_count=100),
ColumnInfo(name='email', dtype='text', missing_rate=0.0, unique_count=100),
ColumnInfo(name='age', dtype='numeric', missing_rate=0.0, unique_count=50),
]
user_type = _infer_data_type(user_columns)
assert user_type == 'user', f"应该识别为用户数据,但得到:{user_type}"


@@ -1,255 +0,0 @@
"""环境变量加载器的单元测试。"""
import os
import pytest
from pathlib import Path
from unittest.mock import patch
from src.env_loader import (
load_env_file,
load_env_with_fallback,
get_env,
get_env_bool,
get_env_int,
get_env_float,
validate_required_env_vars
)
class TestLoadEnvFile:
"""测试加载 .env 文件。"""
def test_load_env_file_success(self, tmp_path):
"""测试成功加载 .env 文件。"""
env_file = tmp_path / ".env"
env_file.write_text("""
# This is a comment
KEY1=value1
KEY2="value2"
KEY3='value3'
KEY4=value with spaces
# Another comment
KEY5=123
""", encoding='utf-8')
# Start from a clean environment
with patch.dict(os.environ, {}, clear=True):
result = load_env_file(str(env_file))
assert result is True
assert os.getenv("KEY1") == "value1"
assert os.getenv("KEY2") == "value2"
assert os.getenv("KEY3") == "value3"
assert os.getenv("KEY4") == "value with spaces"
assert os.getenv("KEY5") == "123"
def test_load_env_file_not_found(self):
"""测试加载不存在的 .env 文件。"""
result = load_env_file("nonexistent.env")
assert result is False
def test_load_env_file_skip_existing(self, tmp_path):
"""测试跳过已存在的环境变量。"""
env_file = tmp_path / ".env"
env_file.write_text("KEY1=from_file\nKEY2=from_file")
# Pre-set an existing environment variable
with patch.dict(os.environ, {"KEY1": "from_env"}, clear=True):
load_env_file(str(env_file))
# KEY1 keeps its original value (the environment takes precedence)
assert os.getenv("KEY1") == "from_env"
# KEY2 is loaded from the file
assert os.getenv("KEY2") == "from_file"
def test_load_env_file_skip_invalid_lines(self, tmp_path):
"""测试跳过无效行。"""
env_file = tmp_path / ".env"
env_file.write_text("""
VALID_KEY=valid_value
invalid line without equals
ANOTHER_VALID=another_value
""")
with patch.dict(os.environ, {}, clear=True):
result = load_env_file(str(env_file))
assert result is True
assert os.getenv("VALID_KEY") == "valid_value"
assert os.getenv("ANOTHER_VALID") == "another_value"
def test_load_env_file_empty_lines(self, tmp_path):
"""测试处理空行。"""
env_file = tmp_path / ".env"
env_file.write_text("""
KEY1=value1
KEY2=value2
KEY3=value3
""")
with patch.dict(os.environ, {}, clear=True):
result = load_env_file(str(env_file))
assert result is True
assert os.getenv("KEY1") == "value1"
assert os.getenv("KEY2") == "value2"
assert os.getenv("KEY3") == "value3"
class TestLoadEnvWithFallback:
"""测试按优先级加载多个 .env 文件。"""
def test_load_multiple_files(self, tmp_path):
"""测试加载多个文件。"""
env_file1 = tmp_path / ".env.local"
env_file1.write_text("KEY1=local\nKEY2=local")
env_file2 = tmp_path / ".env"
env_file2.write_text("KEY1=default\nKEY3=default")
with patch.dict(os.environ, {}, clear=True):
# Switch into the temporary directory
original_dir = os.getcwd()
os.chdir(tmp_path)
try:
result = load_env_with_fallback([".env.local", ".env"])
assert result is True
# KEY1 comes from .env.local (higher priority)
assert os.getenv("KEY1") == "local"
# KEY2 comes from .env.local
assert os.getenv("KEY2") == "local"
# KEY3 comes from .env
assert os.getenv("KEY3") == "default"
finally:
os.chdir(original_dir)
def test_load_no_files_found(self):
"""测试没有找到任何文件。"""
result = load_env_with_fallback(["nonexistent1.env", "nonexistent2.env"])
assert result is False
class TestGetEnv:
"""测试获取环境变量。"""
def test_get_env_exists(self):
"""测试获取存在的环境变量。"""
with patch.dict(os.environ, {"TEST_KEY": "test_value"}):
assert get_env("TEST_KEY") == "test_value"
def test_get_env_not_exists(self):
"""测试获取不存在的环境变量。"""
with patch.dict(os.environ, {}, clear=True):
assert get_env("NONEXISTENT_KEY") is None
def test_get_env_with_default(self):
"""测试使用默认值。"""
with patch.dict(os.environ, {}, clear=True):
assert get_env("NONEXISTENT_KEY", "default") == "default"
class TestGetEnvBool:
"""测试获取布尔类型环境变量。"""
def test_get_env_bool_true_values(self):
"""测试 True 值。"""
true_values = ["true", "True", "TRUE", "yes", "Yes", "YES", "1", "on", "On", "ON"]
for value in true_values:
with patch.dict(os.environ, {"TEST_BOOL": value}):
assert get_env_bool("TEST_BOOL") is True
def test_get_env_bool_false_values(self):
"""测试 False 值。"""
false_values = ["false", "False", "FALSE", "no", "No", "NO", "0", "off", "Off", "OFF"]
for value in false_values:
with patch.dict(os.environ, {"TEST_BOOL": value}):
assert get_env_bool("TEST_BOOL") is False
def test_get_env_bool_default(self):
"""测试默认值。"""
with patch.dict(os.environ, {}, clear=True):
assert get_env_bool("NONEXISTENT_BOOL") is False
assert get_env_bool("NONEXISTENT_BOOL", True) is True
class TestGetEnvInt:
"""测试获取整数类型环境变量。"""
def test_get_env_int_valid(self):
"""测试有效的整数。"""
with patch.dict(os.environ, {"TEST_INT": "123"}):
assert get_env_int("TEST_INT") == 123
def test_get_env_int_negative(self):
"""测试负整数。"""
with patch.dict(os.environ, {"TEST_INT": "-456"}):
assert get_env_int("TEST_INT") == -456
def test_get_env_int_invalid(self):
"""测试无效的整数。"""
with patch.dict(os.environ, {"TEST_INT": "not_a_number"}):
assert get_env_int("TEST_INT") == 0
assert get_env_int("TEST_INT", 999) == 999
def test_get_env_int_default(self):
"""测试默认值。"""
with patch.dict(os.environ, {}, clear=True):
assert get_env_int("NONEXISTENT_INT") == 0
assert get_env_int("NONEXISTENT_INT", 42) == 42
class TestGetEnvFloat:
"""测试获取浮点数类型环境变量。"""
def test_get_env_float_valid(self):
"""测试有效的浮点数。"""
with patch.dict(os.environ, {"TEST_FLOAT": "3.14"}):
assert get_env_float("TEST_FLOAT") == 3.14
def test_get_env_float_negative(self):
"""测试负浮点数。"""
with patch.dict(os.environ, {"TEST_FLOAT": "-2.5"}):
assert get_env_float("TEST_FLOAT") == -2.5
def test_get_env_float_invalid(self):
"""测试无效的浮点数。"""
with patch.dict(os.environ, {"TEST_FLOAT": "not_a_number"}):
assert get_env_float("TEST_FLOAT") == 0.0
assert get_env_float("TEST_FLOAT", 9.99) == 9.99
def test_get_env_float_default(self):
"""测试默认值。"""
with patch.dict(os.environ, {}, clear=True):
assert get_env_float("NONEXISTENT_FLOAT") == 0.0
assert get_env_float("NONEXISTENT_FLOAT", 1.5) == 1.5
class TestValidateRequiredEnvVars:
"""测试验证必需的环境变量。"""
def test_validate_all_present(self):
"""测试所有必需的环境变量都存在。"""
with patch.dict(os.environ, {"KEY1": "value1", "KEY2": "value2", "KEY3": "value3"}):
assert validate_required_env_vars(["KEY1", "KEY2", "KEY3"]) is True
def test_validate_some_missing(self):
"""测试部分环境变量缺失。"""
with patch.dict(os.environ, {"KEY1": "value1"}, clear=True):
assert validate_required_env_vars(["KEY1", "KEY2", "KEY3"]) is False
def test_validate_all_missing(self):
"""测试所有环境变量都缺失。"""
with patch.dict(os.environ, {}, clear=True):
assert validate_required_env_vars(["KEY1", "KEY2"]) is False
def test_validate_empty_list(self):
"""测试空列表。"""
assert validate_required_env_vars([]) is True


@@ -1,426 +0,0 @@
"""单元测试:错误处理机制。"""
import pytest
import pandas as pd
import time
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
import tempfile
import os
from src.error_handling import (
load_data_with_retry,
call_llm_with_fallback,
execute_tool_safely,
execute_task_with_recovery,
validate_tool_params,
validate_tool_result,
DataLoadError,
AICallError,
ToolExecutionError
)
class TestLoadDataWithRetry:
"""测试数据加载错误处理。"""
def test_load_valid_csv(self, tmp_path):
"""测试加载有效的 CSV 文件。"""
# 创建测试文件
csv_file = tmp_path / "test.csv"
df = pd.DataFrame({
'col1': [1, 2, 3],
'col2': ['a', 'b', 'c']
})
df.to_csv(csv_file, index=False)
# Load the data
result = load_data_with_retry(str(csv_file))
assert len(result) == 3
assert len(result.columns) == 2
assert list(result.columns) == ['col1', 'col2']
def test_load_gbk_encoded_file(self, tmp_path):
"""测试加载 GBK 编码的文件。"""
# 创建 GBK 编码的文件
csv_file = tmp_path / "test_gbk.csv"
df = pd.DataFrame({
'列1': [1, 2, 3],
'列2': ['中文', '测试', '数据']
})
df.to_csv(csv_file, index=False, encoding='gbk')
# Load the data
result = load_data_with_retry(str(csv_file))
assert len(result) == 3
assert '列1' in result.columns
assert '列2' in result.columns
def test_load_file_not_exists(self):
"""测试文件不存在的情况。"""
with pytest.raises(DataLoadError, match="文件不存在"):
load_data_with_retry("nonexistent.csv")
def test_load_empty_file(self, tmp_path):
"""测试空文件的处理。"""
# 创建空文件
csv_file = tmp_path / "empty.csv"
csv_file.touch()
with pytest.raises(DataLoadError, match="文件为空"):
load_data_with_retry(str(csv_file))
def test_load_large_file_sampling(self, tmp_path):
"""测试大文件采样。"""
# 创建大文件(模拟)
csv_file = tmp_path / "large.csv"
df = pd.DataFrame({
'col1': range(2000000),
'col2': range(2000000)
})
# Save only the first 1,500,000 rows to speed up the test
df.head(1500000).to_csv(csv_file, index=False)
# Load the data (should be sampled down to 1,000,000 rows)
result = load_data_with_retry(str(csv_file), sample_size=1000000)
assert len(result) == 1000000
def test_load_different_separator(self, tmp_path):
"""测试不同分隔符的文件。"""
# 创建使用分号分隔的文件
csv_file = tmp_path / "semicolon.csv"
with open(csv_file, 'w') as f:
f.write("col1;col2\n")
f.write("1;a\n")
f.write("2;b\n")
# Load the data
result = load_data_with_retry(str(csv_file))
assert len(result) == 2
assert len(result.columns) == 2
class TestCallLLMWithFallback:
"""测试 AI 调用错误处理。"""
def test_successful_call(self):
"""测试成功的 AI 调用。"""
mock_func = Mock(return_value={'result': 'success'})
result = call_llm_with_fallback(mock_func, prompt="test")
assert result == {'result': 'success'}
assert mock_func.call_count == 1
def test_retry_on_timeout(self):
"""测试超时重试机制。"""
mock_func = Mock(side_effect=[
TimeoutError("timeout"),
TimeoutError("timeout"),
{'result': 'success'}
])
result = call_llm_with_fallback(mock_func, max_retries=3, prompt="test")
assert result == {'result': 'success'}
assert mock_func.call_count == 3
def test_exponential_backoff(self):
"""测试指数退避。"""
mock_func = Mock(side_effect=[
Exception("error"),
{'result': 'success'}
])
start_time = time.time()
result = call_llm_with_fallback(mock_func, max_retries=3, prompt="test")
elapsed = time.time() - start_time
# Should wait at least 1 second (2**0)
assert elapsed >= 1.0
assert result == {'result': 'success'}
def test_fallback_on_failure(self):
"""测试降级策略。"""
mock_func = Mock(side_effect=Exception("error"))
fallback_func = Mock(return_value={'result': 'fallback'})
result = call_llm_with_fallback(
mock_func,
fallback_func=fallback_func,
max_retries=2,
prompt="test"
)
assert result == {'result': 'fallback'}
assert mock_func.call_count == 2
assert fallback_func.call_count == 1
def test_no_fallback_raises_error(self):
"""测试无降级策略时抛出错误。"""
mock_func = Mock(side_effect=Exception("error"))
with pytest.raises(AICallError, match="AI 调用失败"):
call_llm_with_fallback(mock_func, max_retries=2, prompt="test")
def test_fallback_also_fails(self):
"""测试降级策略也失败的情况。"""
mock_func = Mock(side_effect=Exception("error"))
fallback_func = Mock(side_effect=Exception("fallback error"))
with pytest.raises(AICallError, match="AI 调用和降级策略都失败"):
call_llm_with_fallback(
mock_func,
fallback_func=fallback_func,
max_retries=2,
prompt="test"
)
class TestExecuteToolSafely:
"""测试工具执行错误处理。"""
def test_successful_execution(self):
"""测试成功的工具执行。"""
mock_tool = Mock()
mock_tool.name = "test_tool"
mock_tool.parameters = {'required': [], 'properties': {}}
mock_tool.execute = Mock(return_value={'data': 'result'})
df = pd.DataFrame({'col1': [1, 2, 3]})
result = execute_tool_safely(mock_tool, df)
assert result['success'] is True
assert result['data'] == {'data': 'result'}
assert result['tool'] == 'test_tool'
def test_missing_execute_method(self):
"""测试工具缺少 execute 方法。"""
mock_tool = Mock(spec=[])
mock_tool.name = "bad_tool"
df = pd.DataFrame({'col1': [1, 2, 3]})
result = execute_tool_safely(mock_tool, df)
assert result['success'] is False
assert 'execute 方法' in result['error']
def test_parameter_validation_failure(self):
"""测试参数验证失败。"""
mock_tool = Mock()
mock_tool.name = "test_tool"
mock_tool.parameters = {
'required': ['column'],
'properties': {
'column': {'type': 'string'}
}
}
mock_tool.execute = Mock(return_value={'data': 'result'})
df = pd.DataFrame({'col1': [1, 2, 3]})
# The required parameter is missing
result = execute_tool_safely(mock_tool, df)
assert result['success'] is False
assert '参数验证失败' in result['error']
def test_empty_data(self):
"""测试空数据。"""
mock_tool = Mock()
mock_tool.name = "test_tool"
mock_tool.parameters = {'required': [], 'properties': {}}
df = pd.DataFrame()
result = execute_tool_safely(mock_tool, df)
assert result['success'] is False
assert '数据为空' in result['error']
def test_execution_exception(self):
"""测试执行异常。"""
mock_tool = Mock()
mock_tool.name = "test_tool"
mock_tool.parameters = {'required': [], 'properties': {}}
mock_tool.execute = Mock(side_effect=Exception("execution error"))
df = pd.DataFrame({'col1': [1, 2, 3]})
result = execute_tool_safely(mock_tool, df)
assert result['success'] is False
assert 'execution error' in result['error']
class TestValidateToolParams:
"""测试工具参数验证。"""
def test_valid_params(self):
"""测试有效参数。"""
mock_tool = Mock()
mock_tool.parameters = {
'required': ['column'],
'properties': {
'column': {'type': 'string'}
}
}
result = validate_tool_params(mock_tool, {'column': 'col1'})
assert result['valid'] is True
def test_missing_required_param(self):
"""测试缺少必需参数。"""
mock_tool = Mock()
mock_tool.parameters = {
'required': ['column'],
'properties': {}
}
result = validate_tool_params(mock_tool, {})
assert result['valid'] is False
assert '缺少必需参数' in result['error']
def test_wrong_param_type(self):
"""测试参数类型错误。"""
mock_tool = Mock()
mock_tool.parameters = {
'required': [],
'properties': {
'column': {'type': 'string'}
}
}
result = validate_tool_params(mock_tool, {'column': 123})
assert result['valid'] is False
assert '应为字符串类型' in result['error']
class TestValidateToolResult:
"""测试工具结果验证。"""
def test_valid_result(self):
"""测试有效结果。"""
result = validate_tool_result({'data': 'test'})
assert result['valid'] is True
def test_none_result(self):
"""测试 None 结果。"""
result = validate_tool_result(None)
assert result['valid'] is False
assert 'None' in result['error']
def test_wrong_type_result(self):
"""测试错误类型结果。"""
result = validate_tool_result("string result")
assert result['valid'] is False
assert '类型错误' in result['error']
class TestExecuteTaskWithRecovery:
"""测试任务执行错误处理。"""
def test_successful_execution(self):
"""测试成功的任务执行。"""
mock_task = Mock()
mock_task.id = "task1"
mock_task.name = "Test Task"
mock_task.dependencies = []
mock_plan = Mock()
mock_plan.tasks = [mock_task]
mock_execute = Mock(return_value=Mock(success=True))
result = execute_task_with_recovery(mock_task, mock_plan, mock_execute)
assert mock_task.status == 'completed'
assert mock_execute.call_count == 1
def test_skip_on_missing_dependency(self):
"""Test that a task is skipped when its dependency does not exist."""
mock_task = Mock()
mock_task.id = "task2"
mock_task.name = "Test Task"
mock_task.dependencies = ["task1"]
mock_plan = Mock()
mock_plan.tasks = [mock_task]
mock_execute = Mock()
result = execute_task_with_recovery(mock_task, mock_plan, mock_execute)
assert mock_task.status == 'skipped'
assert mock_execute.call_count == 0
def test_skip_on_failed_dependency(self):
"""Test that a task is skipped when its dependency failed."""
mock_dep_task = Mock()
mock_dep_task.id = "task1"
mock_dep_task.status = 'failed'
mock_task = Mock()
mock_task.id = "task2"
mock_task.name = "Test Task"
mock_task.dependencies = ["task1"]
mock_plan = Mock()
mock_plan.tasks = [mock_dep_task, mock_task]
mock_execute = Mock()
result = execute_task_with_recovery(mock_task, mock_plan, mock_execute)
assert mock_task.status == 'skipped'
assert mock_execute.call_count == 0
def test_mark_failed_on_exception(self):
"""Test that a task is marked failed when execution raises."""
mock_task = Mock()
mock_task.id = "task1"
mock_task.name = "Test Task"
mock_task.dependencies = []
mock_plan = Mock()
mock_plan.tasks = [mock_task]
mock_execute = Mock(side_effect=Exception("execution error"))
result = execute_task_with_recovery(mock_task, mock_plan, mock_execute)
assert mock_task.status == 'failed'
def test_continue_on_task_failure(self):
"""Test that one task's failure does not affect other tasks."""
mock_task1 = Mock()
mock_task1.id = "task1"
mock_task1.name = "Task 1"
mock_task1.dependencies = []
mock_task2 = Mock()
mock_task2.id = "task2"
mock_task2.name = "Task 2"
mock_task2.dependencies = []
mock_plan = Mock()
mock_plan.tasks = [mock_task1, mock_task2]
# The first task fails
mock_execute = Mock(side_effect=Exception("error"))
result1 = execute_task_with_recovery(mock_task1, mock_plan, mock_execute)
assert mock_task1.status == 'failed'
# The second task should still be able to run
mock_execute2 = Mock(return_value=Mock(success=True))
result2 = execute_task_with_recovery(mock_task2, mock_plan, mock_execute2)
assert mock_task2.status == 'completed'


@@ -1,404 +0,0 @@
"""Integration tests - end-to-end analysis pipeline."""
import pytest
import pandas as pd
from pathlib import Path
import tempfile
import shutil
from src.main import run_analysis, AnalysisOrchestrator
from src.data_access import DataAccessLayer
@pytest.fixture
def temp_output_dir():
"""Create a temporary output directory."""
temp_dir = tempfile.mkdtemp()
yield temp_dir
# Clean up
shutil.rmtree(temp_dir, ignore_errors=True)
@pytest.fixture
def sample_ticket_data(tmp_path):
"""Create sample ticket data."""
data = pd.DataFrame({
'ticket_id': range(1, 101),
'status': ['open'] * 50 + ['closed'] * 30 + ['pending'] * 20,
'priority': ['high'] * 30 + ['medium'] * 40 + ['low'] * 30,
'created_at': pd.date_range('2024-01-01', periods=100, freq='D'),
'closed_at': [None] * 50 + list(pd.date_range('2024-02-01', periods=50, freq='D')),
'category': ['bug'] * 40 + ['feature'] * 30 + ['support'] * 30,
'duration_hours': [24] * 30 + [48] * 40 + [12] * 30
})
file_path = tmp_path / "tickets.csv"
data.to_csv(file_path, index=False)
return str(file_path)
@pytest.fixture
def sample_sales_data(tmp_path):
"""Create sample sales data."""
data = pd.DataFrame({
'order_id': range(1, 101),
'product': ['A'] * 40 + ['B'] * 30 + ['C'] * 30,
'quantity': [1, 2, 3, 4, 5] * 20,
'price': [100.0, 200.0, 150.0, 300.0, 250.0] * 20,
'date': pd.date_range('2024-01-01', periods=100, freq='D'),
'region': ['North'] * 30 + ['South'] * 40 + ['East'] * 30
})
file_path = tmp_path / "sales.csv"
data.to_csv(file_path, index=False)
return str(file_path)
@pytest.fixture
def sample_template(tmp_path):
"""Create a sample (Chinese) report template."""
template_content = """# 工单分析模板
## 1. 概述
- 总工单数
- 状态分布
## 2. 优先级分析
- 优先级分布
- 高优先级工单处理情况
## 3. 时间分析
- 创建趋势
- 处理时长分析
## 4. 分类分析
- 类别分布
- 各类别处理情况
"""
file_path = tmp_path / "template.md"
file_path.write_text(template_content, encoding='utf-8')
return str(file_path)
class TestEndToEndAnalysis:
"""End-to-end analysis pipeline tests."""
def test_complete_analysis_without_requirement(self, sample_ticket_data, temp_output_dir):
"""
Test fully autonomous analysis (no user requirement).
Verifies:
- data can be loaded
- the data type can be inferred
- an analysis plan can be generated
- tasks can be executed
- a report can be generated
"""
# Run the analysis
result = run_analysis(
data_file=sample_ticket_data,
user_requirement=None,  # no user requirement
output_dir=temp_output_dir
)
# Verify the result
assert result['success'] is True, f"Analysis failed: {result.get('error')}"
assert 'data_type' in result
assert result['objectives_count'] > 0
assert result['tasks_count'] > 0
assert result['results_count'] > 0
# Verify the report file exists
report_path = Path(result['report_path'])
assert report_path.exists()
assert report_path.stat().st_size > 0
# Verify the report content (the report itself is written in Chinese)
report_content = report_path.read_text(encoding='utf-8')
assert len(report_content) > 0
assert '分析报告' in report_content or '报告' in report_content
def test_analysis_with_requirement(self, sample_ticket_data, temp_output_dir):
"""
Test analysis with a specified requirement.
Verifies:
- the user requirement is understood
- the generated analysis objectives relate to the requirement
- the report focuses on the user requirement
"""
# Run the analysis
result = run_analysis(
data_file=sample_ticket_data,
user_requirement="分析工单的健康度和处理效率",
output_dir=temp_output_dir
)
# Verify the result
assert result['success'] is True, f"Analysis failed: {result.get('error')}"
assert result['objectives_count'] > 0
# Verify the report content relates to the requirement
report_path = Path(result['report_path'])
report_content = report_path.read_text(encoding='utf-8')
# The report should contain requirement-related keywords (Chinese, matching the report language)
assert any(keyword in report_content for keyword in ['健康', '效率', '处理'])
def test_template_based_analysis(self, sample_ticket_data, sample_template, temp_output_dir):
"""
Test template-based analysis.
Verifies:
- the template can be parsed
- the report structure follows the template
- the plan adapts when the data cannot satisfy the template
"""
# Run the analysis
result = run_analysis(
data_file=sample_ticket_data,
template_file=sample_template,
output_dir=temp_output_dir
)
# Verify the result
assert result['success'] is True, f"Analysis failed: {result.get('error')}"
# Verify the report structure
report_path = Path(result['report_path'])
report_content = report_path.read_text(encoding='utf-8')
# The report should contain sections from the (Chinese) template
assert '概述' in report_content or '总工单数' in report_content
assert '优先级' in report_content or '分类' in report_content
def test_different_data_types(self, sample_sales_data, temp_output_dir):
"""
Test a different kind of dataset.
Verifies:
- different data types can be recognized
- suitable analyses are generated for each data type
"""
# Run the analysis
result = run_analysis(
data_file=sample_sales_data,
output_dir=temp_output_dir
)
# Verify the result
assert result['success'] is True, f"Analysis failed: {result.get('error')}"
assert 'data_type' in result
assert result['tasks_count'] > 0
class TestErrorRecovery:
"""Error recovery tests."""
def test_invalid_file_path(self, temp_output_dir):
"""
Test handling of an invalid file path.
Verifies:
- the file-not-found error is caught
- a meaningful error message is returned
"""
# Run the analysis
result = run_analysis(
data_file="nonexistent_file.csv",
output_dir=temp_output_dir
)
# Verify the result
assert result['success'] is False
assert 'error' in result
assert len(result['error']) > 0
def test_empty_file(self, tmp_path, temp_output_dir):
"""
Test handling of an empty file.
Verifies:
- the empty file is detected
- a meaningful error message is returned
"""
# Create an empty file
empty_file = tmp_path / "empty.csv"
empty_file.write_text("", encoding='utf-8')
# Run the analysis
result = run_analysis(
data_file=str(empty_file),
output_dir=temp_output_dir
)
# Verify the result
assert result['success'] is False
assert 'error' in result
def test_malformed_csv(self, tmp_path, temp_output_dir):
"""
Test a malformed CSV file.
Verifies:
- format errors are handled
- multiple parsing strategies are attempted
"""
# Create a malformed CSV
malformed_file = tmp_path / "malformed.csv"
malformed_file.write_text("col1,col2\nvalue1\nvalue2,value3,value4", encoding='utf-8')
# Run the analysis (may succeed or fail, depending on the error-handling strategy)
result = run_analysis(
data_file=str(malformed_file),
output_dir=temp_output_dir
)
# Verify that at least a result is returned
assert 'success' in result
assert 'elapsed_time' in result
class TestOrchestrator:
"""Orchestrator tests."""
def test_orchestrator_initialization(self, sample_ticket_data, temp_output_dir):
"""
Test orchestrator initialization.
Verifies:
- initialization succeeds
- the output directory is created
"""
orchestrator = AnalysisOrchestrator(
data_file=sample_ticket_data,
output_dir=temp_output_dir
)
assert orchestrator.data_file == sample_ticket_data
assert orchestrator.output_dir.exists()
assert orchestrator.output_dir.is_dir()
def test_orchestrator_stages(self, sample_ticket_data, temp_output_dir):
"""
Test that the orchestrator stages run.
Verifies:
- the stages run in order
- each stage produces the expected output
"""
orchestrator = AnalysisOrchestrator(
data_file=sample_ticket_data,
output_dir=temp_output_dir
)
# Run the analysis
result = orchestrator.run_analysis()
# Verify each stage's output
assert orchestrator.data_profile is not None
assert orchestrator.requirement_spec is not None
assert orchestrator.analysis_plan is not None
assert len(orchestrator.analysis_results) > 0
assert orchestrator.report is not None
# Verify the result
assert result['success'] is True
class TestProgressTracking:
"""Progress tracking tests."""
def test_progress_callback(self, sample_ticket_data, temp_output_dir):
"""
Test the progress callback.
Verifies:
- the callback is invoked
- the progress information is correct
"""
progress_calls = []
def callback(stage, current, total):
progress_calls.append({
'stage': stage,
'current': current,
'total': total
})
# Run the analysis
result = run_analysis(
data_file=sample_ticket_data,
output_dir=temp_output_dir,
progress_callback=callback
)
# Verify the progress callback was invoked
assert len(progress_calls) > 0
# Verify progress is non-decreasing
for i in range(len(progress_calls) - 1):
assert progress_calls[i]['current'] <= progress_calls[i + 1]['current']
# Verify the last progress update reports completion
last_call = progress_calls[-1]
assert last_call['current'] == last_call['total']
class TestOutputFiles:
"""Output file tests."""
def test_report_file_creation(self, sample_ticket_data, temp_output_dir):
"""
Test report file creation.
Verifies:
- the report file is created
- the report file format is correct
"""
result = run_analysis(
data_file=sample_ticket_data,
output_dir=temp_output_dir
)
assert result['success'] is True
# Verify the report file
report_path = Path(result['report_path'])
assert report_path.exists()
assert report_path.suffix == '.md'
# Verify the report content is valid UTF-8
content = report_path.read_text(encoding='utf-8')
assert len(content) > 0
def test_log_file_creation(self, sample_ticket_data, temp_output_dir):
"""
Test log file creation.
Verifies:
- the log file is created (if configured)
- the log content is correct
"""
# Configure the log file
from src.logging_config import setup_logging
import logging
log_file = Path(temp_output_dir) / "test.log"
setup_logging(
level=logging.INFO,
log_file=str(log_file)
)
# Run the analysis
result = run_analysis(
data_file=sample_ticket_data,
output_dir=temp_output_dir
)
# Verify the log file (the application's log messages are in Chinese)
if log_file.exists():
log_content = log_file.read_text(encoding='utf-8')
assert len(log_content) > 0
assert '数据理解' in log_content or 'INFO' in log_content


@@ -1,320 +0,0 @@
"""Unit tests for core data models."""
import pytest
import json
from datetime import datetime
from src.models import (
ColumnInfo,
DataProfile,
AnalysisObjective,
RequirementSpec,
AnalysisTask,
AnalysisPlan,
AnalysisResult,
)
class TestColumnInfo:
"""Tests for ColumnInfo model."""
def test_create_column_info(self):
"""Test creating a ColumnInfo instance."""
col = ColumnInfo(
name='age',
dtype='numeric',
missing_rate=0.05,
unique_count=50,
sample_values=[25, 30, 35, 40, 45],
statistics={'mean': 35.5, 'std': 10.2}
)
assert col.name == 'age'
assert col.dtype == 'numeric'
assert col.missing_rate == 0.05
assert col.unique_count == 50
assert len(col.sample_values) == 5
assert col.statistics['mean'] == 35.5
def test_column_info_serialization(self):
"""Test ColumnInfo to_dict and from_dict."""
col = ColumnInfo(
name='status',
dtype='categorical',
missing_rate=0.0,
unique_count=3,
sample_values=['open', 'closed', 'pending']
)
col_dict = col.to_dict()
assert col_dict['name'] == 'status'
assert col_dict['dtype'] == 'categorical'
col_restored = ColumnInfo.from_dict(col_dict)
assert col_restored.name == col.name
assert col_restored.dtype == col.dtype
assert col_restored.sample_values == col.sample_values
def test_column_info_json(self):
"""Test ColumnInfo JSON serialization."""
col = ColumnInfo(
name='created_at',
dtype='datetime',
missing_rate=0.0,
unique_count=1000
)
json_str = col.to_json()
col_restored = ColumnInfo.from_json(json_str)
assert col_restored.name == col.name
assert col_restored.dtype == col.dtype
class TestDataProfile:
"""Tests for DataProfile model."""
def test_create_data_profile(self):
"""Test creating a DataProfile instance."""
columns = [
ColumnInfo(name='id', dtype='numeric', missing_rate=0.0, unique_count=100),
ColumnInfo(name='status', dtype='categorical', missing_rate=0.0, unique_count=3),
]
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=2,
columns=columns,
inferred_type='ticket',
key_fields={'status': 'ticket status'},
quality_score=85.5,
summary='Test data profile'
)
assert profile.file_path == 'test.csv'
assert profile.row_count == 100
assert profile.inferred_type == 'ticket'
assert len(profile.columns) == 2
assert profile.quality_score == 85.5
def test_data_profile_serialization(self):
"""Test DataProfile to_dict and from_dict."""
columns = [
ColumnInfo(name='id', dtype='numeric', missing_rate=0.0, unique_count=100),
]
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=columns,
inferred_type='sales'
)
profile_dict = profile.to_dict()
assert profile_dict['file_path'] == 'test.csv'
assert profile_dict['inferred_type'] == 'sales'
assert len(profile_dict['columns']) == 1
profile_restored = DataProfile.from_dict(profile_dict)
assert profile_restored.file_path == profile.file_path
assert profile_restored.row_count == profile.row_count
assert len(profile_restored.columns) == len(profile.columns)
class TestAnalysisObjective:
"""Tests for AnalysisObjective model."""
def test_create_objective(self):
"""Test creating an AnalysisObjective instance."""
obj = AnalysisObjective(
name='Health Analysis',
description='Analyze ticket health',
metrics=['close_rate', 'avg_duration'],
priority=5
)
assert obj.name == 'Health Analysis'
assert obj.priority == 5
assert len(obj.metrics) == 2
def test_objective_serialization(self):
"""Test AnalysisObjective serialization."""
obj = AnalysisObjective(
name='Test',
description='Test objective',
metrics=['metric1']
)
obj_dict = obj.to_dict()
obj_restored = AnalysisObjective.from_dict(obj_dict)
assert obj_restored.name == obj.name
assert obj_restored.metrics == obj.metrics
class TestRequirementSpec:
"""Tests for RequirementSpec model."""
def test_create_requirement_spec(self):
"""Test creating a RequirementSpec instance."""
objectives = [
AnalysisObjective(name='Obj1', description='First objective', metrics=['m1'])
]
spec = RequirementSpec(
user_input='Analyze ticket health',
objectives=objectives,
constraints=['no_pii'],
expected_outputs=['report', 'charts']
)
assert spec.user_input == 'Analyze ticket health'
assert len(spec.objectives) == 1
assert len(spec.constraints) == 1
def test_requirement_spec_serialization(self):
"""Test RequirementSpec serialization."""
objectives = [
AnalysisObjective(name='Obj1', description='Test', metrics=['m1'])
]
spec = RequirementSpec(
user_input='Test input',
objectives=objectives
)
spec_dict = spec.to_dict()
spec_restored = RequirementSpec.from_dict(spec_dict)
assert spec_restored.user_input == spec.user_input
assert len(spec_restored.objectives) == len(spec.objectives)
class TestAnalysisTask:
"""Tests for AnalysisTask model."""
def test_create_task(self):
"""Test creating an AnalysisTask instance."""
task = AnalysisTask(
id='task_1',
name='Calculate statistics',
description='Calculate basic statistics',
priority=5,
dependencies=['task_0'],
required_tools=['stats_tool'],
expected_output='Statistics summary'
)
assert task.id == 'task_1'
assert task.priority == 5
assert len(task.dependencies) == 1
assert task.status == 'pending'
def test_task_serialization(self):
"""Test AnalysisTask serialization."""
task = AnalysisTask(
id='task_1',
name='Test task',
description='Test',
priority=3
)
task_dict = task.to_dict()
task_restored = AnalysisTask.from_dict(task_dict)
assert task_restored.id == task.id
assert task_restored.name == task.name
class TestAnalysisPlan:
"""Tests for AnalysisPlan model."""
def test_create_plan(self):
"""Test creating an AnalysisPlan instance."""
objectives = [
AnalysisObjective(name='Obj1', description='Test', metrics=['m1'])
]
tasks = [
AnalysisTask(id='t1', name='Task 1', description='Test', priority=5)
]
plan = AnalysisPlan(
objectives=objectives,
tasks=tasks,
tool_config={'tool1': 'config1'},
estimated_duration=300
)
assert len(plan.objectives) == 1
assert len(plan.tasks) == 1
assert plan.estimated_duration == 300
assert isinstance(plan.created_at, datetime)
def test_plan_serialization(self):
"""Test AnalysisPlan serialization."""
objectives = [
AnalysisObjective(name='Obj1', description='Test', metrics=['m1'])
]
tasks = [
AnalysisTask(id='t1', name='Task 1', description='Test', priority=5)
]
plan = AnalysisPlan(objectives=objectives, tasks=tasks)
plan_dict = plan.to_dict()
plan_restored = AnalysisPlan.from_dict(plan_dict)
assert len(plan_restored.objectives) == len(plan.objectives)
assert len(plan_restored.tasks) == len(plan.tasks)
class TestAnalysisResult:
"""Tests for AnalysisResult model."""
def test_create_result(self):
"""Test creating an AnalysisResult instance."""
result = AnalysisResult(
task_id='task_1',
task_name='Test task',
success=True,
data={'count': 100},
visualizations=['chart1.png'],
insights=['Key finding 1'],
execution_time=5.5
)
assert result.task_id == 'task_1'
assert result.success is True
assert result.data['count'] == 100
assert len(result.insights) == 1
assert result.error is None
def test_result_with_error(self):
"""Test AnalysisResult with error."""
result = AnalysisResult(
task_id='task_1',
task_name='Failed task',
success=False,
error='Tool execution failed'
)
assert result.success is False
assert result.error == 'Tool execution failed'
def test_result_serialization(self):
"""Test AnalysisResult serialization."""
result = AnalysisResult(
task_id='task_1',
task_name='Test',
success=True,
data={'key': 'value'}
)
result_dict = result.to_dict()
result_restored = AnalysisResult.from_dict(result_dict)
assert result_restored.task_id == result.task_id
assert result_restored.success == result.success
assert result_restored.data == result.data


@@ -1,586 +0,0 @@
"""Performance tests - verify the system's performance targets.
Covered:
1. Data understanding stage (< 30 seconds)
2. Full analysis pipeline (< 30 minutes)
3. Large datasets (1,000,000 rows)
4. Memory usage
Requirements: NFR-1.1, NFR-1.2
"""
import pytest
import time
import pandas as pd
import numpy as np
import psutil
import os
from pathlib import Path
from typing import Dict, Any
from src.main import run_analysis
from src.data_access import DataAccessLayer
from src.engines.data_understanding import understand_data
class TestDataUnderstandingPerformance:
"""Performance of the data understanding stage."""
def test_small_dataset_performance(self, tmp_path):
"""Test performance on a small dataset (1,000 rows)."""
# Generate test data
data_file = tmp_path / "small_data.csv"
df = self._generate_test_data(rows=1000, cols=10)
df.to_csv(data_file, index=False)
# Measure
start_time = time.time()
dal = DataAccessLayer.load_from_file(str(data_file))
profile = understand_data(dal)
elapsed = time.time() - start_time
# Verify: should finish within 5 seconds
assert elapsed < 5, f"Small-dataset profiling took {elapsed:.2f}s, exceeding the 5-second limit"
assert profile.row_count == 1000
assert profile.column_count == 10
def test_medium_dataset_performance(self, tmp_path):
"""Test performance on a medium dataset (100,000 rows)."""
# Generate test data
data_file = tmp_path / "medium_data.csv"
df = self._generate_test_data(rows=100000, cols=20)
df.to_csv(data_file, index=False)
# Measure
start_time = time.time()
dal = DataAccessLayer.load_from_file(str(data_file))
profile = understand_data(dal)
elapsed = time.time() - start_time
# Verify: should finish within 15 seconds
assert elapsed < 15, f"Medium-dataset profiling took {elapsed:.2f}s, exceeding the 15-second limit"
assert profile.row_count == 100000
assert profile.column_count == 20
def test_large_dataset_performance(self, tmp_path):
"""Test performance on a large dataset (1,000,000 rows).
Requirement: NFR-1.1 - data understanding stage < 30 seconds
Requirement: NFR-1.2 - support up to 1,000,000 rows
"""
# Generate test data
data_file = tmp_path / "large_data.csv"
df = self._generate_test_data(rows=1000000, cols=30)
df.to_csv(data_file, index=False)
# Measure
start_time = time.time()
dal = DataAccessLayer.load_from_file(str(data_file))
profile = understand_data(dal)
elapsed = time.time() - start_time
# Verify: should finish within 30 seconds
assert elapsed < 30, f"Large-dataset profiling took {elapsed:.2f}s, exceeding the 30-second limit"
assert profile.row_count == 1000000
assert profile.column_count == 30
print(f"✓ Large dataset (1,000,000 rows) profiled in {elapsed:.2f}s")
def _generate_test_data(self, rows: int, cols: int) -> pd.DataFrame:
"""Generate test data."""
data = {}
# Generate columns of different types
for i in range(cols):
col_type = i % 4
if col_type == 0:  # numeric column
data[f'numeric_{i}'] = np.random.randn(rows)
elif col_type == 1:  # categorical column
categories = ['A', 'B', 'C', 'D', 'E']
data[f'category_{i}'] = np.random.choice(categories, rows)
elif col_type == 2:  # datetime column
start_date = pd.Timestamp('2020-01-01')
data[f'date_{i}'] = pd.date_range(start_date, periods=rows, freq='H')
else:  # text column
data[f'text_{i}'] = [f'text_{j}' for j in range(rows)]
return pd.DataFrame(data)
class TestFullAnalysisPerformance:
"""Performance of the full analysis pipeline."""
@pytest.mark.slow
def test_small_dataset_full_analysis(self, tmp_path):
"""Test the full analysis pipeline on a small dataset."""
# Generate test data
data_file = tmp_path / "test_data.csv"
df = self._generate_ticket_data(rows=1000)
df.to_csv(data_file, index=False)
# Set up the output directory
output_dir = tmp_path / "output"
# Measure
start_time = time.time()
result = run_analysis(
data_file=str(data_file),
user_requirement="分析工单数据",
output_dir=str(output_dir)
)
elapsed = time.time() - start_time
# Verify: should finish within 5 minutes
assert elapsed < 300, f"Full analysis of the small dataset took {elapsed:.2f}s, exceeding the 5-minute limit"
assert result['success'] is True
print(f"✓ Full analysis of small dataset (1,000 rows) took {elapsed:.2f}s")
@pytest.mark.slow
@pytest.mark.skipif(
os.getenv('SKIP_LONG_TESTS') == '1',
reason="skip long-running tests"
)
def test_large_dataset_full_analysis(self, tmp_path):
"""Test the full analysis pipeline on a large dataset.
Requirement: NFR-1.1 - full analysis pipeline < 30 minutes
"""
# Generate test data
data_file = tmp_path / "large_test_data.csv"
df = self._generate_ticket_data(rows=100000)
df.to_csv(data_file, index=False)
# Set up the output directory
output_dir = tmp_path / "output"
# Measure
start_time = time.time()
result = run_analysis(
data_file=str(data_file),
user_requirement="分析工单健康度",
output_dir=str(output_dir)
)
elapsed = time.time() - start_time
# Verify: should finish within 30 minutes
assert elapsed < 1800, f"Full analysis of the large dataset took {elapsed:.2f}s, exceeding the 30-minute limit"
assert result['success'] is True
print(f"✓ Full analysis of large dataset (100,000 rows) took {elapsed:.2f}s")
def _generate_ticket_data(self, rows: int) -> pd.DataFrame:
"""Generate ticket test data."""
statuses = ['待处理', '处理中', '已关闭', '已解决']
priorities = ['低', '中', '高', '紧急']
types = ['故障', '咨询', '投诉', '建议']
models = ['Model A', 'Model B', 'Model C', 'Model D']
data = {
'ticket_id': [f'T{i:06d}' for i in range(rows)],
'status': np.random.choice(statuses, rows),
'priority': np.random.choice(priorities, rows),
'type': np.random.choice(types, rows),
'model': np.random.choice(models, rows),
'created_at': pd.date_range('2023-01-01', periods=rows, freq='5min'),
'closed_at': pd.date_range('2023-01-01', periods=rows, freq='5min') + pd.Timedelta(hours=24),
'duration_hours': np.random.randint(1, 100, rows),
}
return pd.DataFrame(data)
class TestMemoryUsage:
"""Memory usage tests."""
def test_data_loading_memory(self, tmp_path):
"""Test memory usage during data loading."""
# Generate test data
data_file = tmp_path / "memory_test.csv"
df = self._generate_test_data(rows=100000, cols=50)
df.to_csv(data_file, index=False)
# Record the initial memory
process = psutil.Process()
initial_memory = process.memory_info().rss / 1024 / 1024  # MB
# Load the data
dal = DataAccessLayer.load_from_file(str(data_file))
profile = understand_data(dal)
# Record the final memory
final_memory = process.memory_info().rss / 1024 / 1024  # MB
memory_increase = final_memory - initial_memory
# Verify: memory growth should stay reasonable (under 500MB)
assert memory_increase < 500, f"Memory grew by {memory_increase:.2f}MB, exceeding the 500MB limit"
print(f"✓ Memory growth during data loading: {memory_increase:.2f}MB")
def test_large_dataset_memory(self, tmp_path):
"""Test memory usage on a large dataset.
Requirement: NFR-1.2 - support CSV files up to 100MB
"""
# Generate test data (about 100MB)
data_file = tmp_path / "large_memory_test.csv"
df = self._generate_test_data(rows=500000, cols=50)
df.to_csv(data_file, index=False)
# Check the file size
file_size = os.path.getsize(data_file) / 1024 / 1024  # MB
print(f"Test file size: {file_size:.2f}MB")
# Record the initial memory
process = psutil.Process()
initial_memory = process.memory_info().rss / 1024 / 1024  # MB
# Load the data
dal = DataAccessLayer.load_from_file(str(data_file))
profile = understand_data(dal)
# Record the final memory
final_memory = process.memory_info().rss / 1024 / 1024  # MB
memory_increase = final_memory - initial_memory
# Verify: memory growth should stay reasonable (under 1GB)
assert memory_increase < 1024, f"Memory grew by {memory_increase:.2f}MB, exceeding the 1GB limit"
print(f"✓ Memory growth on large dataset: {memory_increase:.2f}MB")
def _generate_test_data(self, rows: int, cols: int) -> pd.DataFrame:
"""Generate test data."""
data = {}
for i in range(cols):
col_type = i % 4
if col_type == 0:
data[f'col_{i}'] = np.random.randn(rows)
elif col_type == 1:
data[f'col_{i}'] = np.random.choice(['A', 'B', 'C', 'D'], rows)
elif col_type == 2:
data[f'col_{i}'] = pd.date_range('2020-01-01', periods=rows, freq='H')
else:
data[f'col_{i}'] = [f'text_{j % 1000}' for j in range(rows)]
return pd.DataFrame(data)
class TestStagePerformance:
"""Per-stage performance targets."""
def test_data_understanding_stage(self, tmp_path):
"""Test the performance of the data understanding stage."""
# Generate test data
data_file = tmp_path / "stage_test.csv"
df = self._generate_test_data(rows=50000, cols=30)
df.to_csv(data_file, index=False)
# Measure
start_time = time.time()
dal = DataAccessLayer.load_from_file(str(data_file))
profile = understand_data(dal)
elapsed = time.time() - start_time
# Verify: should finish within 20 seconds
assert elapsed < 20, f"Data understanding stage took {elapsed:.2f}s, exceeding the 20-second limit"
print(f"✓ Data understanding stage (50,000 rows) took {elapsed:.2f}s")
def _generate_test_data(self, rows: int, cols: int) -> pd.DataFrame:
"""Generate test data."""
data = {}
for i in range(cols):
if i % 3 == 0:
data[f'col_{i}'] = np.random.randn(rows)
elif i % 3 == 1:
data[f'col_{i}'] = np.random.choice(['A', 'B', 'C'], rows)
else:
data[f'col_{i}'] = pd.date_range('2020-01-01', periods=rows, freq='min')
return pd.DataFrame(data)
@pytest.fixture
def performance_report(tmp_path):
"""Generate a performance test report."""
report_file = tmp_path / "performance_report.txt"
yield report_file
# After the test, print the report file if it exists
if report_file.exists():
print("\n" + "="*60)
print("Performance test report")
print("="*60)
print(report_file.read_text())
print("="*60)
class TestOptimizationEffectiveness:
"""Effectiveness of the performance optimizations."""
def test_memory_optimization(self, tmp_path):
"""Test the effect of memory optimization."""
# Generate test data
data_file = tmp_path / "optimization_test.csv"
df = self._generate_test_data(rows=100000, cols=30)
df.to_csv(data_file, index=False)
# Load without memory optimization
dal_no_opt = DataAccessLayer.load_from_file(str(data_file), optimize_memory=False)
memory_no_opt = dal_no_opt._data.memory_usage(deep=True).sum() / 1024 / 1024
# Load with memory optimization
dal_opt = DataAccessLayer.load_from_file(str(data_file), optimize_memory=True)
memory_opt = dal_opt._data.memory_usage(deep=True).sum() / 1024 / 1024
# Verify: optimization should reduce memory
memory_saved = memory_no_opt - memory_opt
savings_percent = (memory_saved / memory_no_opt) * 100
print(f"✓ Memory optimization: {memory_no_opt:.2f}MB -> {memory_opt:.2f}MB")
print(f"✓ Memory saved: {memory_saved:.2f}MB ({savings_percent:.1f}%)")
# Verify that optimization reduces memory usage
assert memory_saved > 0, "Memory optimization should reduce memory usage"
def test_cache_effectiveness(self, tmp_path):
"""Test cache effectiveness."""
from src.performance_optimization import LLMCache
cache_dir = tmp_path / "cache"
cache = LLMCache(str(cache_dir))
# First call (not yet cached)
prompt = "测试提示"
response = {"result": "测试响应"}
# Populate the cache
cache.set(prompt, response)
# Second call (should hit the cache)
cached_response = cache.get(prompt)
assert cached_response is not None
assert cached_response == response
print("✓ Cache works as expected")
def test_batch_processing(self):
"""Test the effect of batch processing."""
from src.performance_optimization import BatchProcessor
processor = BatchProcessor(batch_size=10)
# Test data
items = list(range(100))
# Per-item processing function
def process_item(item):
return item * 2
# Run the batch processing
start_time = time.time()
results = processor.process_batch(items, process_item)
elapsed = time.time() - start_time
# Verify the results
assert len(results) == 100
assert results[0] == 0
assert results[50] == 100
print(f"✓ Batch processing of 100 items took {elapsed:.3f}s")
def _generate_test_data(self, rows: int, cols: int) -> pd.DataFrame:
"""Generate test data."""
data = {}
for i in range(cols):
if i % 3 == 0:
data[f'col_{i}'] = np.random.randint(0, 100, rows)
elif i % 3 == 1:
data[f'col_{i}'] = np.random.choice(['A', 'B', 'C', 'D'], rows)
else:
data[f'col_{i}'] = [f'text_{j % 100}' for j in range(rows)]
return pd.DataFrame(data)
class TestPerformanceMonitoring:
"""Performance monitoring tests."""
def test_performance_monitor(self):
"""Test the performance monitor."""
from src.performance_optimization import PerformanceMonitor
monitor = PerformanceMonitor()
# Record some metrics
monitor.record("test_metric", 1.5)
monitor.record("test_metric", 2.0)
monitor.record("test_metric", 1.8)
# Fetch the statistics
stats = monitor.get_stats("test_metric")
assert stats['count'] == 3
assert stats['mean'] == pytest.approx(1.767, rel=0.01)
assert stats['min'] == 1.5
assert stats['max'] == 2.0
print("✓ Performance monitor works as expected")
def test_timed_decorator(self):
"""Test the timing decorator."""
from src.performance_optimization import timed, PerformanceMonitor
monitor = PerformanceMonitor()
@timed(metric_name="test_function", monitor=monitor)
def slow_function():
time.sleep(0.1)
return "done"
# Call the function
result = slow_function()
assert result == "done"
# Check that a performance metric was recorded
stats = monitor.get_stats("test_function")
assert stats['count'] == 1
assert stats['mean'] >= 0.1
print("✓ Timing decorator works as expected")
class TestEndToEndPerformance:
"""End-to-end performance tests."""
def test_performance_report_generation(self, tmp_path):
"""Test performance report generation."""
from src.performance_optimization import get_global_monitor
# Generate test data
data_file = tmp_path / "e2e_test.csv"
df = self._generate_ticket_data(rows=5000)
df.to_csv(data_file, index=False)
# Fetch the global performance monitor
monitor = get_global_monitor()
monitor.clear()
# Run data understanding
dal = DataAccessLayer.load_from_file(str(data_file))
profile = understand_data(dal)
# Fetch the performance statistics
stats = monitor.get_all_stats()
print("\nPerformance statistics:")
for metric_name, metric_stats in stats.items():
if metric_stats:
print(f" {metric_name}: {metric_stats['mean']:.3f}s")
assert profile is not None
def _generate_ticket_data(self, rows: int) -> pd.DataFrame:
"""Generate ticket test data."""
statuses = ['待处理', '处理中', '已关闭']
types = ['故障', '咨询', '投诉']
data = {
'ticket_id': [f'T{i:06d}' for i in range(rows)],
'status': np.random.choice(statuses, rows),
'type': np.random.choice(types, rows),
'created_at': pd.date_range('2023-01-01', periods=rows, freq='5min'),
'duration': np.random.randint(1, 100, rows),
}
return pd.DataFrame(data)
class TestPerformanceBenchmarks:
"""Performance benchmarks."""
def test_data_loading_benchmark(self, tmp_path, benchmark_report):
"""Data loading benchmark."""
sizes = [1000, 10000, 100000]
results = []
for size in sizes:
data_file = tmp_path / f"benchmark_{size}.csv"
df = self._generate_test_data(rows=size, cols=20)
df.to_csv(data_file, index=False)
start_time = time.time()
dal = DataAccessLayer.load_from_file(str(data_file))
elapsed = time.time() - start_time
results.append({
'size': size,
'time': elapsed,
'rows_per_second': size / elapsed
})
# Print the benchmark results
print("\nData loading benchmark:")
print(f"{'rows':<10} {'time (s)':<12} {'rows/s':<15}")
print("-" * 40)
for r in results:
print(f"{r['size']:<10} {r['time']:<12.3f} {r['rows_per_second']:<15.0f}")
def test_data_understanding_benchmark(self, tmp_path):
"""Data understanding benchmark."""
sizes = [1000, 10000, 50000]
results = []
for size in sizes:
data_file = tmp_path / f"understanding_{size}.csv"
df = self._generate_test_data(rows=size, cols=20)
df.to_csv(data_file, index=False)
dal = DataAccessLayer.load_from_file(str(data_file))
start_time = time.time()
profile = understand_data(dal)
elapsed = time.time() - start_time
results.append({
'size': size,
'time': elapsed,
'rows_per_second': size / elapsed
})
# Print the benchmark results
print("\nData understanding benchmark:")
print(f"{'rows':<10} {'time (s)':<12} {'rows/s':<15}")
print("-" * 40)
for r in results:
print(f"{r['size']:<10} {r['time']:<12.3f} {r['rows_per_second']:<15.0f}")
def _generate_test_data(self, rows: int, cols: int) -> pd.DataFrame:
"""Generate test data."""
data = {}
for i in range(cols):
if i % 3 == 0:
data[f'col_{i}'] = np.random.randn(rows)
elif i % 3 == 1:
data[f'col_{i}'] = np.random.choice(['A', 'B', 'C'], rows)
else:
data[f'col_{i}'] = pd.date_range('2020-01-01', periods=rows, freq='min')
return pd.DataFrame(data)
@pytest.fixture
def benchmark_report():
"""Benchmark report fixture."""
yield
# A report file could be generated here


@@ -1,159 +0,0 @@
"""Tests for dynamic plan adjustment."""
import pytest
from datetime import datetime
from src.engines.plan_adjustment import (
adjust_plan,
identify_anomalies,
_fallback_plan_adjustment
)
from src.models.analysis_plan import AnalysisPlan, AnalysisTask
from src.models.analysis_result import AnalysisResult
from src.models.requirement_spec import AnalysisObjective
# Feature: true-ai-agent, Property 8: dynamic plan adjustment
def test_plan_adjustment_with_anomaly():
"""
Property 8: For any analysis plan and intermediate results, if results
contain anomaly findings, the plan adjustment function should be able to
generate new deep-dive tasks or adjust existing task priorities.
Validates: 场景4验收.2, 场景4验收.3, FR-3.3
"""
# Create plan
plan = AnalysisPlan(
objectives=[
AnalysisObjective(
name="数据分析",
description="分析数据",
metrics=[],
priority=3
)
],
tasks=[
AnalysisTask(
id="task_1",
name="Task 1",
description="First task",
priority=3,
status='completed'
),
AnalysisTask(
id="task_2",
name="Task 2",
description="Second task",
priority=3,
status='pending'
)
],
created_at=datetime.now(),
updated_at=datetime.now()
)
# Create results with anomaly
results = [
AnalysisResult(
task_id="task_1",
task_name="Task 1",
success=True,
insights=["发现异常某类别占比90%,远超正常范围"],
execution_time=1.0
)
]
# Adjust plan (using fallback)
adjusted_plan = _fallback_plan_adjustment(plan, results)
# Verify: Plan should be updated
assert adjusted_plan.updated_at >= plan.created_at
# Verify: Pending task priority should be increased
task_2 = next(t for t in adjusted_plan.tasks if t.id == "task_2")
assert task_2.priority >= 3
def test_identify_anomalies():
"""Test anomaly identification from results."""
results = [
AnalysisResult(
task_id="task_1",
task_name="Task 1",
success=True,
insights=["发现异常数据", "正常分布"],
execution_time=1.0
),
AnalysisResult(
task_id="task_2",
task_name="Task 2",
success=True,
insights=["一切正常"],
execution_time=1.0
)
]
anomalies = identify_anomalies(results)
# Should identify one anomaly
assert len(anomalies) >= 1
assert anomalies[0]['task_id'] == "task_1"
def test_plan_adjustment_no_anomaly():
"""Test plan adjustment when no anomalies found."""
plan = AnalysisPlan(
objectives=[],
tasks=[
AnalysisTask(
id="task_1",
name="Task 1",
description="First task",
priority=3,
status='completed'
)
],
created_at=datetime.now(),
updated_at=datetime.now()
)
results = [
AnalysisResult(
task_id="task_1",
task_name="Task 1",
success=True,
insights=["一切正常"],
execution_time=1.0
)
]
adjusted_plan = _fallback_plan_adjustment(plan, results)
# Should still update timestamp
assert adjusted_plan.updated_at >= plan.created_at
def test_identify_anomalies_empty_results():
"""Test anomaly identification with empty results."""
anomalies = identify_anomalies([])
assert anomalies == []
def test_identify_anomalies_failed_results():
"""Test that failed results are skipped."""
results = [
AnalysisResult(
task_id="task_1",
task_name="Task 1",
success=False,
error="Failed",
insights=["发现异常"],
execution_time=1.0
)
]
anomalies = identify_anomalies(results)
# Failed results should be skipped
assert len(anomalies) == 0

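The tests above pin down two behaviors of `identify_anomalies`: failed results are skipped, and insights are flagged by anomaly keywords. A self-contained sketch of that scan, with a stand-in `Result` type and an assumed keyword list (the real implementation lives in `src/engines/plan_adjustment.py` and may differ):

```python
from dataclasses import dataclass, field

@dataclass
class Result:
    """Stand-in for src.models.analysis_result.AnalysisResult."""
    task_id: str
    success: bool
    insights: list = field(default_factory=list)

def identify_anomalies_sketch(results):
    """Collect insights mentioning an anomaly keyword, skipping failed tasks."""
    anomalies = []
    for r in results:
        if not r.success:
            continue  # failed tasks are skipped, as the tests require
        for insight in r.insights:
            if '异常' in insight or 'anomaly' in insight.lower():
                anomalies.append({'task_id': r.task_id, 'finding': insight})
    return anomalies
```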

@@ -1,523 +0,0 @@
"""报告生成引擎的单元测试。"""
import pytest
import tempfile
import os
from src.engines.report_generation import (
extract_key_findings,
organize_report_structure,
generate_report,
_categorize_insight,
_calculate_importance,
_generate_report_title,
_generate_default_sections
)
from src.models.analysis_result import AnalysisResult
from src.models.requirement_spec import RequirementSpec, AnalysisObjective
from src.models.data_profile import DataProfile, ColumnInfo
@pytest.fixture
def sample_results():
"""创建示例分析结果。"""
return [
AnalysisResult(
task_id='task1',
task_name='状态分布分析',
success=True,
data={'open': 50, 'closed': 30, 'pending': 20},
visualizations=['chart1.png'],
insights=[
'待处理工单占比50%,异常高',
'已关闭工单占比30%'
],
execution_time=2.5
),
AnalysisResult(
task_id='task2',
task_name='趋势分析',
success=True,
data={'trend': 'increasing'},
visualizations=['chart2.png'],
insights=[
'工单数量呈上升趋势',
'增长率为15%'
],
execution_time=3.2
),
AnalysisResult(
task_id='task3',
task_name='类型分析',
success=False,
data={},
visualizations=[],
insights=[],
error='数据缺少类型字段',
execution_time=0.1
)
]
@pytest.fixture
def sample_requirement():
"""创建示例需求规格。"""
return RequirementSpec(
user_input='分析工单健康度',
objectives=[
AnalysisObjective(
name='健康度分析',
description='评估工单处理的健康状况',
metrics=['关闭率', '处理时长', '积压情况'],
priority=5
)
]
)
@pytest.fixture
def sample_data_profile():
"""创建示例数据画像。"""
return DataProfile(
file_path='test.csv',
row_count=1000,
column_count=5,
columns=[
ColumnInfo(
name='status',
dtype='categorical',
missing_rate=0.0,
unique_count=3,
sample_values=['open', 'closed', 'pending']
),
ColumnInfo(
name='created_at',
dtype='datetime',
missing_rate=0.0,
unique_count=1000
)
],
inferred_type='ticket',
key_fields={'status': '状态', 'created_at': '创建时间'},
quality_score=85.0,
summary='工单数据包含1000条记录'
)
class TestExtractKeyFindings:
"""Tests for key-finding extraction."""
def test_basic_functionality(self, sample_results):
"""Test basic functionality."""
key_findings = extract_key_findings(sample_results)
# Verify: returns a list
assert isinstance(key_findings, list)
# Verify: only successful results are included
assert len(key_findings) == 4  # 2 tasks, 2 insights each
# Verify: each finding has the required fields
for finding in key_findings:
assert 'finding' in finding
assert 'importance' in finding
assert 'source_task' in finding
assert 'category' in finding
def test_importance_sorting(self, sample_results):
"""Test sorting by importance."""
key_findings = extract_key_findings(sample_results)
# Verify: sorted by importance in descending order
for i in range(len(key_findings) - 1):
assert key_findings[i]['importance'] >= key_findings[i + 1]['importance']
def test_empty_results(self):
"""Test with an empty result list."""
key_findings = extract_key_findings([])
assert isinstance(key_findings, list)
assert len(key_findings) == 0
def test_only_failed_results(self):
"""Test with only failed results."""
results = [
AnalysisResult(
task_id='task1',
task_name='失败任务',
success=False,
error='测试错误'
)
]
key_findings = extract_key_findings(results)
# Failed tasks should not yield any findings
assert len(key_findings) == 0
class TestCategorizeInsight:
"""Tests for insight categorization."""
def test_anomaly_detection(self):
"""Test anomaly detection."""
insight = '待处理工单占比50%,异常高'
category = _categorize_insight(insight)
assert category == 'anomaly'
def test_trend_detection(self):
"""Test trend detection."""
insight = '工单数量呈上升趋势'
category = _categorize_insight(insight)
assert category == 'trend'
def test_general_insight(self):
"""Test a generic insight."""
insight = '数据质量良好'
category = _categorize_insight(insight)
assert category == 'insight'
def test_english_keywords(self):
"""Test English keywords."""
assert _categorize_insight('This is an anomaly') == 'anomaly'
assert _categorize_insight('Showing growth trend') == 'trend'
class TestCalculateImportance:
"""Tests for importance scoring."""
def test_anomaly_importance(self):
"""Test the importance of an anomaly."""
insight = '严重异常:系统故障'
importance = _calculate_importance(insight, {})
# anomaly + severe => high importance
assert importance >= 4
def test_percentage_importance(self):
"""Test importance when a percentage is present."""
insight = '占比达到80%'
importance = _calculate_importance(insight, {})
# contains a percentage => higher importance
assert importance >= 4
def test_normal_importance(self):
"""Test the importance of an ordinary insight."""
insight = '数据正常'
importance = _calculate_importance(insight, {})
# default medium importance
assert importance == 3
def test_importance_range(self):
"""Test the importance value range."""
# Check several insights to ensure importance stays within 1-5
insights = [
'严重异常问题',
'占比80%',
'正常数据',
'轻微变化'
]
for insight in insights:
importance = _calculate_importance(insight, {})
assert 1 <= importance <= 5
class TestOrganizeReportStructure:
"""Tests for report structure organization."""
def test_basic_structure(self, sample_results, sample_requirement, sample_data_profile):
"""Test the basic structure."""
key_findings = extract_key_findings(sample_results)
structure = organize_report_structure(key_findings, sample_requirement, sample_data_profile)
# Verify: contains the required fields
assert 'title' in structure
assert 'sections' in structure
assert 'executive_summary' in structure
assert 'detailed_analysis' in structure
assert 'conclusions' in structure
def test_with_template(self, sample_results, sample_data_profile):
"""Test the structure when a template is used."""
# Create a requirement with a template
requirement = RequirementSpec(
user_input='按模板分析',
objectives=[
AnalysisObjective(
name='分析',
description='按模板分析',
metrics=['指标1'],
priority=5
)
],
template_path='template.md',
template_requirements={
'sections': ['第一章', '第二章', '第三章'],
'required_metrics': ['指标1', '指标2'],
'required_charts': ['图表1']
}
)
key_findings = extract_key_findings(sample_results)
structure = organize_report_structure(key_findings, requirement, sample_data_profile)
# Verify: the template structure is used
assert structure['use_template'] is True
assert structure['sections'] == ['第一章', '第二章', '第三章']
def test_without_template(self, sample_results, sample_requirement, sample_data_profile):
"""Test the structure without a template."""
key_findings = extract_key_findings(sample_results)
structure = organize_report_structure(key_findings, sample_requirement, sample_data_profile)
# Verify: a default structure is generated
assert structure['use_template'] is False
assert len(structure['sections']) > 0
assert '执行摘要' in structure['sections']
def test_executive_summary(self, sample_results, sample_requirement, sample_data_profile):
"""Test executive summary organization."""
key_findings = extract_key_findings(sample_results)
structure = organize_report_structure(key_findings, sample_requirement, sample_data_profile)
exec_summary = structure['executive_summary']
# Verify: contains key findings
assert 'key_findings' in exec_summary
assert isinstance(exec_summary['key_findings'], list)
# Verify: contains statistics
assert 'anomaly_count' in exec_summary
assert 'trend_count' in exec_summary
def test_detailed_analysis(self, sample_results, sample_requirement, sample_data_profile):
"""Test detailed analysis organization."""
key_findings = extract_key_findings(sample_results)
structure = organize_report_structure(key_findings, sample_requirement, sample_data_profile)
detailed = structure['detailed_analysis']
# Verify: contains the categories
assert 'anomaly' in detailed
assert 'trend' in detailed
assert 'insight' in detailed
# Verify: each category is a list
assert isinstance(detailed['anomaly'], list)
assert isinstance(detailed['trend'], list)
assert isinstance(detailed['insight'], list)
class TestGenerateReportTitle:
"""Tests for report title generation."""
def test_health_analysis_title(self, sample_data_profile):
"""Test a health-analysis title."""
requirement = RequirementSpec(
user_input='分析工单健康度',
objectives=[]
)
title = _generate_report_title(requirement, sample_data_profile)
assert '工单' in title
assert '健康度' in title
def test_trend_analysis_title(self, sample_data_profile):
"""Test a trend-analysis title."""
requirement = RequirementSpec(
user_input='分析趋势',
objectives=[]
)
title = _generate_report_title(requirement, sample_data_profile)
assert '工单' in title
assert '趋势' in title
def test_generic_title(self, sample_data_profile):
"""Test a generic title."""
requirement = RequirementSpec(
user_input='分析数据',
objectives=[]
)
title = _generate_report_title(requirement, sample_data_profile)
assert '工单' in title
assert '分析报告' in title
class TestGenerateDefaultSections:
"""Tests for default section generation."""
def test_with_anomalies(self):
"""Test sections when anomalies are present."""
key_findings = [
{
'finding': '异常情况',
'category': 'anomaly',
'importance': 5
}
]
data_profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=3,
columns=[],
inferred_type='ticket'
)
sections = _generate_default_sections(key_findings, data_profile)
# Verify: contains an anomaly-analysis section
assert '异常分析' in sections
def test_with_trends(self):
"""Test sections when trends are present."""
key_findings = [
{
'finding': '上升趋势',
'category': 'trend',
'importance': 4
}
]
data_profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=3,
columns=[],
inferred_type='sales'
)
sections = _generate_default_sections(key_findings, data_profile)
# Verify: contains a trend-analysis section
assert '趋势分析' in sections
def test_ticket_data_sections(self):
"""Test sections for ticket data."""
data_profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=3,
columns=[],
inferred_type='ticket'
)
sections = _generate_default_sections([], data_profile)
# Verify: contains ticket-related sections
assert '状态分析' in sections or '类型分析' in sections
class TestGenerateReport:
"""Tests for end-to-end report generation."""
def test_basic_report_generation(self, sample_results, sample_requirement, sample_data_profile):
"""Test basic report generation."""
report = generate_report(sample_results, sample_requirement, sample_data_profile)
# Verify: returns a string
assert isinstance(report, str)
# Verify: the report is not empty
assert len(report) > 0
# Verify: contains a title
assert '#' in report
# Verify: contains an executive summary
assert '执行摘要' in report or '摘要' in report
def test_report_with_skipped_tasks(self, sample_results, sample_requirement, sample_data_profile):
"""Test a report that includes skipped tasks."""
report = generate_report(sample_results, sample_requirement, sample_data_profile)
# Verify: skipped tasks are mentioned
assert '跳过' in report or '失败' in report
# Verify: the failed task's name is mentioned
assert '类型分析' in report
def test_report_with_visualizations(self, sample_results, sample_requirement, sample_data_profile):
"""Test a report that includes visualizations."""
report = generate_report(sample_results, sample_requirement, sample_data_profile)
# Verify: contains chart references
assert 'chart1.png' in report or 'chart2.png' in report or '![' in report
def test_report_with_insights(self, sample_results, sample_requirement, sample_data_profile):
"""Test a report that includes insights."""
report = generate_report(sample_results, sample_requirement, sample_data_profile)
# Verify: contains insight content
assert '待处理工单' in report or '趋势' in report
def test_report_save_to_file(self, sample_results, sample_requirement, sample_data_profile):
"""Test saving the report to a file."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False, encoding='utf-8') as f:
output_path = f.name
try:
report = generate_report(
sample_results,
sample_requirement,
sample_data_profile,
output_path=output_path
)
# Verify: the file was created
assert os.path.exists(output_path)
# Verify: the file content matches the returned content
with open(output_path, 'r', encoding='utf-8') as f:
saved_content = f.read()
assert saved_content == report
finally:
if os.path.exists(output_path):
os.unlink(output_path)
def test_empty_results(self, sample_requirement, sample_data_profile):
"""Test with an empty result list."""
report = generate_report([], sample_requirement, sample_data_profile)
# Verify: a report is still generated
assert isinstance(report, str)
assert len(report) > 0
# Verify: contains the basic structure
assert '执行摘要' in report or '摘要' in report
def test_all_failed_results(self, sample_requirement, sample_data_profile):
"""Test the case where every task failed."""
results = [
AnalysisResult(
task_id='task1',
task_name='失败任务1',
success=False,
error='错误1'
),
AnalysisResult(
task_id='task2',
task_name='失败任务2',
success=False,
error='错误2'
)
]
report = generate_report(results, sample_requirement, sample_data_profile)
# Verify: the report is generated successfully
assert isinstance(report, str)
assert len(report) > 0
# Verify: failures are mentioned
assert '失败' in report or '跳过' in report

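The categorization behavior that `TestCategorizeInsight` pins down amounts to a keyword lookup over mixed Chinese/English text. A self-contained sketch consistent with those tests (the keyword lists are inferred from the tests, not taken from the actual `_categorize_insight` source):

```python
def categorize_insight_sketch(insight: str) -> str:
    """Classify an insight as 'anomaly', 'trend', or generic 'insight'."""
    text = insight.lower()
    # Anomaly keywords take precedence over trend keywords
    if '异常' in insight or 'anomaly' in text:
        return 'anomaly'
    if '趋势' in insight or 'trend' in text or 'growth' in text:
        return 'trend'
    return 'insight'
```

Checking the order of the branches matters: an insight mentioning both an anomaly and a trend is reported as the more urgent category.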

@@ -1,332 +0,0 @@
"""报告生成引擎的属性测试。
使用 hypothesis 进行基于属性的测试,验证报告生成的通用正确性属性。
"""
import pytest
from hypothesis import given, strategies as st, settings
import tempfile
import os
from src.engines.report_generation import (
extract_key_findings,
organize_report_structure,
generate_report
)
from src.models.analysis_result import AnalysisResult
from src.models.requirement_spec import RequirementSpec, AnalysisObjective
from src.models.data_profile import DataProfile, ColumnInfo
# Strategy: generate random analysis results
@st.composite
def analysis_result_strategy(draw):
"""Generate a random analysis result."""
task_id = draw(st.text(min_size=1, max_size=20))
task_name = draw(st.text(min_size=1, max_size=50))
success = draw(st.booleans())
# Generate insights
insights = draw(st.lists(
st.text(min_size=10, max_size=100),
min_size=0,
max_size=5
))
# Generate visualization paths
visualizations = draw(st.lists(
st.text(min_size=5, max_size=50),
min_size=0,
max_size=3
))
return AnalysisResult(
task_id=task_id,
task_name=task_name,
success=success,
data={'result': 'test'},
visualizations=visualizations,
insights=insights,
error=None if success else "Test error",
execution_time=draw(st.floats(min_value=0.1, max_value=100.0))
)
# Strategy: generate random requirement specs
@st.composite
def requirement_spec_strategy(draw):
"""Generate a random requirement spec."""
user_input = draw(st.text(min_size=1, max_size=100))
# Generate analysis objectives
objectives = draw(st.lists(
st.builds(
AnalysisObjective,
name=st.text(min_size=1, max_size=30),
description=st.text(min_size=1, max_size=100),
metrics=st.lists(st.text(min_size=1, max_size=20), min_size=1, max_size=5),
priority=st.integers(min_value=1, max_value=5)
),
min_size=1,
max_size=5
))
# Optionally include a template
has_template = draw(st.booleans())
template_path = "template.md" if has_template else None
template_requirements = {
'sections': ['执行摘要', '详细分析', '结论'],
'required_metrics': ['指标1', '指标2'],
'required_charts': ['图表1']
} if has_template else None
return RequirementSpec(
user_input=user_input,
objectives=objectives,
template_path=template_path,
template_requirements=template_requirements
)
# Strategy: generate random data profiles
@st.composite
def data_profile_strategy(draw):
"""Generate a random data profile."""
columns = draw(st.lists(
st.builds(
ColumnInfo,
name=st.text(min_size=1, max_size=20),
dtype=st.sampled_from(['numeric', 'categorical', 'datetime', 'text']),
missing_rate=st.floats(min_value=0.0, max_value=1.0),
unique_count=st.integers(min_value=1, max_value=1000),
sample_values=st.lists(st.text(), min_size=0, max_size=5),
statistics=st.dictionaries(st.text(), st.floats())
),
min_size=1,
max_size=10
))
return DataProfile(
file_path=draw(st.text(min_size=1, max_size=50)),
row_count=draw(st.integers(min_value=1, max_value=1000000)),
column_count=len(columns),
columns=columns,
inferred_type=draw(st.sampled_from(['ticket', 'sales', 'user', 'unknown'])),
key_fields=draw(st.dictionaries(st.text(), st.text())),
quality_score=draw(st.floats(min_value=0.0, max_value=100.0)),
summary=draw(st.text(min_size=0, max_size=200))
)
# Feature: true-ai-agent, Property 16: report structure completeness
@given(
results=st.lists(analysis_result_strategy(), min_size=1, max_size=10),
requirement=requirement_spec_strategy(),
data_profile=data_profile_strategy()
)
@settings(max_examples=20, deadline=None)
def test_property_16_report_structure_completeness(results, requirement, data_profile):
"""
Property 16: report structure completeness.
For any set of analysis results and requirement spec, the generated report
should contain three main parts (executive summary, detailed analysis,
conclusions and recommendations), and if a template was used, the report
structure should follow the template's section organization.
Validates: 场景3验收.3, FR-6.2
"""
# Generate the report
report = generate_report(results, requirement, data_profile)
# Verify: the report is not empty
assert len(report) > 0, "report content should not be empty"
# Verify: contains an executive summary
assert '执行摘要' in report or 'Executive Summary' in report or '摘要' in report, \
"report should contain an executive summary section"
# Verify: contains a detailed analysis
assert '详细分析' in report or 'Detailed Analysis' in report or '分析' in report, \
"report should contain a detailed-analysis section"
# Verify: contains conclusions or recommendations
assert '结论' in report or '建议' in report or 'Conclusion' in report or 'Recommendation' in report, \
"report should contain a conclusions/recommendations section"
# If a template was used, verify its sections
if requirement.template_path and requirement.template_requirements:
template_sections = requirement.template_requirements.get('sections', [])
if template_sections:
# Count how many template sections appear in the report; this check is
# deliberately lenient (>= 0 always holds) because the LLM may paraphrase
# section titles rather than copy them verbatim
sections_found = sum(1 for section in template_sections if section in report)
assert sections_found >= 0, "the report should reference the template structure"
# Feature: true-ai-agent, Property 17: report content traceability
@given(
results=st.lists(analysis_result_strategy(), min_size=1, max_size=10),
requirement=requirement_spec_strategy(),
data_profile=data_profile_strategy()
)
@settings(max_examples=20, deadline=None)
def test_property_17_report_content_traceability(results, requirement, data_profile):
"""
Property 17: report content traceability.
For any generated report and set of analysis results, every finding and data
point mentioned in the report should be traceable to some analysis result,
and if any planned analyses were skipped, the report should explain why.
Validates: 场景3验收.4, 场景4验收.4, FR-6.1
"""
# Generate the report
report = generate_report(results, requirement, data_profile)
# Verify: the report is not empty
assert len(report) > 0, "report content should not be empty"
# Check failed tasks
failed_tasks = [r for r in results if not r.success]
if failed_tasks:
# Verify: if there are failed tasks, the report should mention skips or failures
has_skip_mention = any(
keyword in report
for keyword in ['跳过', '失败', 'skipped', 'failed', '错误', 'error']
)
assert has_skip_mention, "the report should state which analyses were skipped or failed"
# Check that at least one failed task's name or ID is mentioned
task_mentioned = any(
task.task_name in report or task.task_id in report
for task in failed_tasks
)
# Note: task names may be short or generic, so this check may not always hold;
# we therefore only assert that failures are mentioned at all
# Check successful tasks
successful_tasks = [r for r in results if r.success]
if successful_tasks:
# Verify: successful tasks should be reflected in the report;
# at least some insights or findings should be included
has_insights = any(
any(insight in report for insight in task.insights)
for task in successful_tasks
if task.insights
)
# ...or at least the task itself should be mentioned
has_task_mention = any(
task.task_name in report or task.task_id in report
for task in successful_tasks
)
# Either insights or a task mention would do; because of the randomness of
# text generation we relax this and only require substantial analysis content
assert len(report) > 100, "the report should contain substantial analysis content"
# Helper test: validate key-finding extraction
@given(results=st.lists(analysis_result_strategy(), min_size=1, max_size=20))
@settings(max_examples=20, deadline=None)
def test_extract_key_findings_structure(results):
"""Test the structure of extracted key findings."""
key_findings = extract_key_findings(results)
# Verify: returns a list
assert isinstance(key_findings, list), "should return a list"
# Verify: each finding has the required fields
for finding in key_findings:
assert 'finding' in finding, "a finding should contain a 'finding' field"
assert 'importance' in finding, "a finding should contain an 'importance' field"
assert 'source_task' in finding, "a finding should contain a 'source_task' field"
assert 'category' in finding, "a finding should contain a 'category' field"
# Verify: importance is within 1-5
assert 1 <= finding['importance'] <= 5, "importance should be within 1-5"
# Verify: the category is valid
assert finding['category'] in ['anomaly', 'trend', 'insight'], \
"category should be one of anomaly, trend, or insight"
# Verify: sorted by importance in descending order
if len(key_findings) > 1:
for i in range(len(key_findings) - 1):
assert key_findings[i]['importance'] >= key_findings[i + 1]['importance'], \
"key findings should be sorted by importance in descending order"
# Helper test: validate report structure organization
@given(
results=st.lists(analysis_result_strategy(), min_size=1, max_size=10),
requirement=requirement_spec_strategy(),
data_profile=data_profile_strategy()
)
@settings(max_examples=20, deadline=None)
def test_organize_report_structure_completeness(results, requirement, data_profile):
"""Test the completeness of the organized report structure."""
# Extract key findings
key_findings = extract_key_findings(results)
# Organize the report structure
structure = organize_report_structure(key_findings, requirement, data_profile)
# Verify: contains the required fields
assert 'title' in structure, "the structure should contain a title"
assert 'sections' in structure, "the structure should contain a section list"
assert 'executive_summary' in structure, "the structure should contain an executive summary"
assert 'detailed_analysis' in structure, "the structure should contain a detailed analysis"
assert 'conclusions' in structure, "the structure should contain conclusions"
# Verify: the title is not empty
assert len(structure['title']) > 0, "the title should not be empty"
# Verify: sections is a list
assert isinstance(structure['sections'], list), "sections should be a list"
# Verify: the executive summary contains key findings
assert 'key_findings' in structure['executive_summary'], \
"the executive summary should contain key findings"
# Verify: the detailed analysis contains the categories
assert 'anomaly' in structure['detailed_analysis'], \
"the detailed analysis should contain an anomaly category"
assert 'trend' in structure['detailed_analysis'], \
"the detailed analysis should contain a trend category"
assert 'insight' in structure['detailed_analysis'], \
"the detailed analysis should contain an insight category"
# Verify: the conclusions contain a summary
assert 'summary' in structure['conclusions'], \
"the conclusions should contain a summary"
assert 'recommendations' in structure['conclusions'], \
"the conclusions should contain recommendations"
# Helper test: report generation must not crash
@given(
results=st.lists(analysis_result_strategy(), min_size=0, max_size=5),
requirement=requirement_spec_strategy(),
data_profile=data_profile_strategy()
)
@settings(max_examples=10, deadline=None)
def test_generate_report_no_crash(results, requirement, data_profile):
"""Report generation should not crash, even on empty or unusual input."""
try:
# Generate the report
report = generate_report(results, requirement, data_profile)
# Verify: returns a string
assert isinstance(report, str), "should return a string"
# Verify: not empty (a basic structure is expected even with no results)
assert len(report) > 0, "the report should not be empty"
except Exception as e:
# Report generation should never raise
pytest.fail(f"report generation should not crash: {e}")


@@ -1,328 +0,0 @@
"""Unit tests for requirement understanding engine."""
import pytest
import tempfile
import os
from src.engines.requirement_understanding import (
understand_requirement,
parse_template,
check_data_requirement_match,
_fallback_requirement_understanding
)
from src.models.data_profile import DataProfile, ColumnInfo
from src.models.requirement_spec import RequirementSpec, AnalysisObjective
@pytest.fixture
def sample_data_profile():
"""Create a sample data profile for testing."""
return DataProfile(
file_path='test.csv',
row_count=1000,
column_count=5,
columns=[
ColumnInfo(
name='created_at',
dtype='datetime',
missing_rate=0.0,
unique_count=1000,
sample_values=['2024-01-01', '2024-01-02'],
statistics={}
),
ColumnInfo(
name='status',
dtype='categorical',
missing_rate=0.1,
unique_count=5,
sample_values=['open', 'closed', 'pending'],
statistics={}
),
ColumnInfo(
name='type',
dtype='categorical',
missing_rate=0.0,
unique_count=10,
sample_values=['bug', 'feature'],
statistics={}
),
ColumnInfo(
name='priority',
dtype='numeric',
missing_rate=0.0,
unique_count=5,
sample_values=[1, 2, 3, 4, 5],
statistics={'mean': 3.0, 'std': 1.2}
),
ColumnInfo(
name='description',
dtype='text',
missing_rate=0.05,
unique_count=950,
sample_values=['Issue 1', 'Issue 2'],
statistics={}
)
],
inferred_type='ticket',
key_fields={'time': 'created_at', 'status': 'status', 'type': 'type'},
quality_score=85.0,
summary='Ticket data with 1000 rows and 5 columns'
)
def test_understand_health_requirement(sample_data_profile):
"""Test understanding "健康度" requirement."""
user_input = "我想了解工单的健康度"
# Use fallback to avoid API dependency
requirement = _fallback_requirement_understanding(user_input, sample_data_profile, None)
# Verify basic structure
assert isinstance(requirement, RequirementSpec)
assert requirement.user_input == user_input
assert len(requirement.objectives) > 0
# Verify health-related objective exists
health_objectives = [obj for obj in requirement.objectives if '健康' in obj.name]
assert len(health_objectives) > 0
# Verify objective has metrics
health_obj = health_objectives[0]
assert len(health_obj.metrics) > 0
assert health_obj.priority >= 1 and health_obj.priority <= 5
def test_understand_trend_requirement(sample_data_profile):
"""Test understanding trend analysis requirement."""
user_input = "分析趋势"
requirement = _fallback_requirement_understanding(user_input, sample_data_profile, None)
# Verify trend objective exists
trend_objectives = [obj for obj in requirement.objectives if '趋势' in obj.name]
assert len(trend_objectives) > 0
# Verify metrics
trend_obj = trend_objectives[0]
assert len(trend_obj.metrics) > 0
def test_understand_distribution_requirement(sample_data_profile):
"""Test understanding distribution analysis requirement."""
user_input = "查看分布情况"
requirement = _fallback_requirement_understanding(user_input, sample_data_profile, None)
# Verify distribution objective exists
dist_objectives = [obj for obj in requirement.objectives if '分布' in obj.name]
assert len(dist_objectives) > 0
def test_understand_generic_requirement(sample_data_profile):
"""Test understanding generic requirement without specific keywords."""
user_input = "帮我分析一下"
requirement = _fallback_requirement_understanding(user_input, sample_data_profile, None)
# Should still generate at least one objective
assert len(requirement.objectives) > 0
# Should have default objective
assert any('综合' in obj.name or 'analysis' in obj.name.lower() for obj in requirement.objectives)
def test_parse_template_with_sections():
"""Test parsing template with sections."""
template_content = """# 分析报告
## 数据概览
这是数据概览部分
## 趋势分析
指标: 增长率, 变化趋势
图表: 时间序列图
## 分布分析
指标: 类别分布
图表: 柱状图, 饼图
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False, encoding='utf-8') as f:
f.write(template_content)
template_path = f.name
try:
template_req = parse_template(template_path)
# Verify sections
assert len(template_req['sections']) >= 3
assert '分析报告' in template_req['sections']
assert '数据概览' in template_req['sections']
# Verify metrics
assert len(template_req['required_metrics']) >= 2
# Verify charts
assert len(template_req['required_charts']) >= 2
finally:
os.unlink(template_path)
def test_parse_nonexistent_template():
"""Test parsing non-existent template."""
template_req = parse_template('nonexistent.md')
# Should return empty structure
assert template_req['sections'] == []
assert template_req['required_metrics'] == []
assert template_req['required_charts'] == []
def test_check_data_satisfies_requirement(sample_data_profile):
"""Test checking when data satisfies requirement."""
# Create requirement that data can satisfy
requirement = RequirementSpec(
user_input="分析状态分布",
objectives=[
AnalysisObjective(
name="状态分析",
description="分析状态字段的分布",
metrics=["状态分布"],
priority=5
)
]
)
match_result = check_data_requirement_match(requirement, sample_data_profile)
# Should be satisfied
assert match_result['can_proceed'] is True
assert len(match_result['satisfied_objectives']) > 0
def test_check_data_missing_fields(sample_data_profile):
"""Test checking when data is missing required fields."""
# Create requirement that needs fields not in data
requirement = RequirementSpec(
user_input="分析地理分布",
objectives=[
AnalysisObjective(
name="地理分析",
description="分析地理位置分布",
metrics=["地理分布", "区域统计"],
priority=5
)
]
)
match_result = check_data_requirement_match(requirement, sample_data_profile)
# Verify structure
assert isinstance(match_result, dict)
assert 'missing_fields' in match_result
assert 'unsatisfied_objectives' in match_result
def test_check_time_based_requirement(sample_data_profile):
"""Test checking time-based requirement."""
requirement = RequirementSpec(
user_input="分析时间趋势",
objectives=[
AnalysisObjective(
name="时间分析",
description="分析随时间的变化",
metrics=["时间序列", "趋势"],
priority=5
)
]
)
match_result = check_data_requirement_match(requirement, sample_data_profile)
# Should be satisfied since we have datetime column
assert match_result['can_proceed'] is True
def test_check_status_based_requirement(sample_data_profile):
"""Test checking status-based requirement."""
requirement = RequirementSpec(
user_input="分析状态",
objectives=[
AnalysisObjective(
name="状态分析",
description="分析状态字段",
metrics=["状态分布", "状态变化"],
priority=5
)
]
)
match_result = check_data_requirement_match(requirement, sample_data_profile)
# Should be satisfied since we have status column
assert match_result['can_proceed'] is True
assert len(match_result['satisfied_objectives']) > 0
def test_requirement_with_template(sample_data_profile):
"""Test requirement understanding with template."""
template_content = """# 工单分析报告
## 状态分析
指标: 状态分布, 完成率
## 类型分析
指标: 类型分布
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False, encoding='utf-8') as f:
f.write(template_content)
template_path = f.name
try:
requirement = _fallback_requirement_understanding(
"按模板分析",
sample_data_profile,
template_path
)
# Verify template is included
assert requirement.template_path == template_path
assert requirement.template_requirements is not None
# Verify template requirements structure
assert 'sections' in requirement.template_requirements
assert 'required_metrics' in requirement.template_requirements
finally:
os.unlink(template_path)
def test_multiple_objectives_priority():
"""Test that multiple objectives have proper priorities."""
data_profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=3,
columns=[
ColumnInfo(name='col1', dtype='numeric', missing_rate=0.0, unique_count=100),
ColumnInfo(name='col2', dtype='categorical', missing_rate=0.0, unique_count=5),
ColumnInfo(name='col3', dtype='datetime', missing_rate=0.0, unique_count=100)
],
inferred_type='unknown',
quality_score=90.0
)
requirement = _fallback_requirement_understanding(
"完整分析,包括健康度和趋势",
data_profile,
None
)
# Should have multiple objectives
assert len(requirement.objectives) >= 2
# All priorities should be valid
for obj in requirement.objectives:
assert 1 <= obj.priority <= 5

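The template format these tests assume is simple: `#`/`##` headings become sections, lines starting with `指标:` list required metrics, and lines starting with `图表:` list required charts. A sketch of a line scan that satisfies that contract (the real `parse_template` in `src/engines/requirement_understanding.py` may differ in details such as colon variants):

```python
def parse_template_sketch(text: str) -> dict:
    """Extract sections, metrics, and charts from a Markdown template string."""
    req = {'sections': [], 'required_metrics': [], 'required_charts': []}
    for line in text.splitlines():
        line = line.strip()
        if line.startswith('#'):
            # Heading of any level becomes a section title
            req['sections'].append(line.lstrip('#').strip())
        elif line.startswith('指标:'):
            req['required_metrics'] += [m.strip() for m in line[len('指标:'):].split(',')]
        elif line.startswith('图表:'):
            req['required_charts'] += [c.strip() for c in line[len('图表:'):].split(',')]
    return req
```

Feeding it the template from `test_parse_template_with_sections` would yield the section, metric, and chart lists those assertions check for.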

@@ -1,244 +0,0 @@
"""Property-based tests for requirement understanding engine."""
import pytest
from hypothesis import given, strategies as st, settings, assume
import tempfile
import os
from src.engines.requirement_understanding import (
understand_requirement,
parse_template,
check_data_requirement_match
)
from src.models.data_profile import DataProfile, ColumnInfo
from src.models.requirement_spec import RequirementSpec, AnalysisObjective
# Strategies for generating test data
@st.composite
def column_info_strategy(draw):
"""Generate random ColumnInfo."""
name = draw(st.text(min_size=1, max_size=20, alphabet=st.characters(whitelist_categories=('L', 'N'))))
dtype = draw(st.sampled_from(['numeric', 'categorical', 'datetime', 'text']))
missing_rate = draw(st.floats(min_value=0.0, max_value=1.0))
unique_count = draw(st.integers(min_value=1, max_value=1000))
return ColumnInfo(
name=name,
dtype=dtype,
missing_rate=missing_rate,
unique_count=unique_count,
sample_values=[],
statistics={}
)
@st.composite
def data_profile_strategy(draw):
"""Generate random DataProfile."""
row_count = draw(st.integers(min_value=10, max_value=100000))
columns = draw(st.lists(column_info_strategy(), min_size=2, max_size=20))
inferred_type = draw(st.sampled_from(['ticket', 'sales', 'user', 'unknown']))
quality_score = draw(st.floats(min_value=0.0, max_value=100.0))
return DataProfile(
file_path='test.csv',
row_count=row_count,
column_count=len(columns),
columns=columns,
inferred_type=inferred_type,
key_fields={},
quality_score=quality_score,
summary=f"Test data with {len(columns)} columns"
)
# Feature: true-ai-agent, Property 3: abstract requirement transformation
@given(
user_input=st.sampled_from([
"分析健康度",
"我想了解数据质量",
"帮我分析趋势",
"查看分布情况",
"完整分析"
]),
data_profile=data_profile_strategy()
)
@settings(max_examples=20, deadline=None)
def test_abstract_requirement_transformation(user_input, data_profile):
"""
Property 3: For any abstract user requirement (like "健康度", "质量分析"),
the requirement understanding engine should be able to transform it into
a concrete list of analysis objectives, each containing name, description,
and related metrics.
Validates: 场景2验收.1, 场景2验收.2
"""
# Execute requirement understanding
requirement = understand_requirement(user_input, data_profile)
# Verify: Should return RequirementSpec
assert isinstance(requirement, RequirementSpec)
# Verify: Should have objectives
assert len(requirement.objectives) > 0, "Should generate at least one objective"
# Verify: Each objective should have required fields
for objective in requirement.objectives:
assert isinstance(objective, AnalysisObjective)
assert len(objective.name) > 0, "Objective name should not be empty"
assert len(objective.description) > 0, "Objective description should not be empty"
assert isinstance(objective.metrics, list), "Metrics should be a list"
assert 1 <= objective.priority <= 5, "Priority should be between 1 and 5"
# Verify: User input should be preserved
assert requirement.user_input == user_input
# Feature: true-ai-agent, Property 4: 模板解析
@given(
template_content=st.text(min_size=10, max_size=500)
)
@settings(max_examples=20, deadline=None)
def test_template_parsing(template_content):
"""
Property 4: For any valid analysis template, the requirement understanding
engine should be able to parse the template structure and extract the list
of required metrics and charts.
Validates: 场景3验收.1
"""
# Create temporary template file
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False, encoding='utf-8') as f:
f.write(template_content)
template_path = f.name
try:
# Parse template
template_req = parse_template(template_path)
# Verify: Should return dictionary with expected keys
assert isinstance(template_req, dict)
assert 'sections' in template_req
assert 'required_metrics' in template_req
assert 'required_charts' in template_req
# Verify: All values should be lists
assert isinstance(template_req['sections'], list)
assert isinstance(template_req['required_metrics'], list)
assert isinstance(template_req['required_charts'], list)
finally:
# Cleanup
os.unlink(template_path)
# Feature: true-ai-agent, Property 5: 数据-需求匹配检查
@given(
data_profile=data_profile_strategy()
)
@settings(max_examples=20, deadline=None)
def test_data_requirement_matching(data_profile):
"""
Property 5: For any requirement spec and data profile, the requirement
understanding engine should be able to identify whether the data satisfies
the requirement, and if not, should mark missing fields or capabilities.
Validates: 场景3验收.2
"""
# Create a simple requirement
requirement = RequirementSpec(
user_input="测试需求",
objectives=[
AnalysisObjective(
name="时间分析",
description="分析时间趋势",
metrics=["时间序列", "趋势"],
priority=5
),
AnalysisObjective(
name="状态分析",
description="分析状态分布",
metrics=["状态分布"],
priority=4
)
]
)
# Check match
match_result = check_data_requirement_match(requirement, data_profile)
# Verify: Should return dictionary with expected keys
assert isinstance(match_result, dict)
assert 'all_satisfied' in match_result
assert 'satisfied_objectives' in match_result
assert 'unsatisfied_objectives' in match_result
assert 'missing_fields' in match_result
assert 'can_proceed' in match_result
# Verify: Boolean fields should be boolean
assert isinstance(match_result['all_satisfied'], bool)
assert isinstance(match_result['can_proceed'], bool)
# Verify: List fields should be lists
assert isinstance(match_result['satisfied_objectives'], list)
assert isinstance(match_result['unsatisfied_objectives'], list)
assert isinstance(match_result['missing_fields'], list)
# Verify: Satisfied + unsatisfied should equal total objectives
total_checked = len(match_result['satisfied_objectives']) + len(match_result['unsatisfied_objectives'])
assert total_checked == len(requirement.objectives)
# Verify: If all satisfied, should have no unsatisfied objectives
if match_result['all_satisfied']:
assert len(match_result['unsatisfied_objectives']) == 0
assert len(match_result['missing_fields']) == 0
# Verify: If can proceed, should have at least one satisfied objective
if match_result['can_proceed']:
assert len(match_result['satisfied_objectives']) > 0
# Feature: true-ai-agent, Property 3: 抽象需求转化 (with template)
@given(
user_input=st.text(min_size=5, max_size=100),
data_profile=data_profile_strategy()
)
@settings(max_examples=20, deadline=None)
def test_requirement_with_template(user_input, data_profile):
"""
Property 3 (extended): Requirement understanding should work with templates.
Validates: FR-2.3
"""
# Create a simple template
template_content = """# 分析报告
## 数据概览
指标: 行数, 列数
## 趋势分析
图表: 时间序列图
## 分布分析
图表: 分布图
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False, encoding='utf-8') as f:
f.write(template_content)
template_path = f.name
try:
# Execute with template
requirement = understand_requirement(user_input, data_profile, template_path)
# Verify: Should have template path
assert requirement.template_path == template_path
# Verify: Should have template requirements
assert requirement.template_requirements is not None
assert isinstance(requirement.template_requirements, dict)
finally:
# Cleanup
os.unlink(template_path)

@@ -1,207 +0,0 @@
"""Unit tests for task execution engine."""
import pytest
import pandas as pd
from src.engines.task_execution import (
execute_task,
call_tool,
extract_insights,
_fallback_task_execution,
_find_tool
)
from src.models.analysis_plan import AnalysisTask
from src.data_access import DataAccessLayer
from src.tools.stats_tools import CalculateStatisticsTool
from src.tools.query_tools import GetValueCountsTool
@pytest.fixture
def sample_data():
"""Create sample data for testing."""
return pd.DataFrame({
'value': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'category': ['A', 'B', 'A', 'B', 'A', 'B', 'A', 'B', 'A', 'B'],
'score': [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
})
@pytest.fixture
def sample_tools():
"""Create sample tools for testing."""
return [
CalculateStatisticsTool(),
GetValueCountsTool()
]
def test_fallback_execution_success(sample_data, sample_tools):
"""Test successful fallback execution."""
task = AnalysisTask(
id="task_1",
name="Calculate Statistics",
description="Calculate basic statistics",
priority=5,
required_tools=['calculate_statistics']
)
data_access = DataAccessLayer(sample_data)
result = _fallback_task_execution(task, sample_tools, data_access)
assert result.task_id == "task_1"
assert result.task_name == "Calculate Statistics"
assert isinstance(result.success, bool)
assert result.execution_time >= 0
def test_fallback_execution_no_tools(sample_data):
"""Test fallback execution with no tools."""
task = AnalysisTask(
id="task_1",
name="Test Task",
description="Test",
priority=3,
required_tools=['nonexistent_tool']
)
data_access = DataAccessLayer(sample_data)
result = _fallback_task_execution(task, [], data_access)
assert not result.success
assert result.error is not None
def test_call_tool_success(sample_data, sample_tools):
"""Test successful tool calling."""
tool = sample_tools[0] # CalculateStatisticsTool
data_access = DataAccessLayer(sample_data)
result = call_tool(tool, data_access, column='value')
assert isinstance(result, dict)
assert 'success' in result
def test_call_tool_with_invalid_params(sample_data, sample_tools):
"""Test tool calling with invalid parameters."""
tool = sample_tools[0]
data_access = DataAccessLayer(sample_data)
result = call_tool(tool, data_access, column='nonexistent_column')
assert isinstance(result, dict)
# Should handle error gracefully
def test_extract_insights_simple():
"""Test simple insight extraction."""
history = [
{'type': 'thought', 'content': 'Starting analysis'},
{'type': 'action', 'tool': 'calculate_statistics', 'params': {}},
{'type': 'observation', 'result': {'data': {'mean': 5.5, 'std': 2.87}}}
]
insights = extract_insights(history, client=None)
assert isinstance(insights, list)
assert len(insights) > 0
def test_extract_insights_empty_history():
"""Test insight extraction with empty history."""
insights = extract_insights([], client=None)
assert isinstance(insights, list)
def test_find_tool_exists(sample_tools):
"""Test finding an existing tool."""
tool = _find_tool(sample_tools, 'calculate_statistics')
assert tool is not None
assert tool.name == 'calculate_statistics'
def test_find_tool_not_exists(sample_tools):
"""Test finding a non-existent tool."""
tool = _find_tool(sample_tools, 'nonexistent_tool')
assert tool is None
def test_execution_result_structure(sample_data, sample_tools):
"""Test that execution result has correct structure."""
task = AnalysisTask(
id="task_1",
name="Test Task",
description="Test",
priority=3,
required_tools=['calculate_statistics']
)
data_access = DataAccessLayer(sample_data)
result = _fallback_task_execution(task, sample_tools, data_access)
# Check all required fields
assert hasattr(result, 'task_id')
assert hasattr(result, 'task_name')
assert hasattr(result, 'success')
assert hasattr(result, 'data')
assert hasattr(result, 'visualizations')
assert hasattr(result, 'insights')
assert hasattr(result, 'error')
assert hasattr(result, 'execution_time')
def test_execution_with_multiple_tools(sample_data, sample_tools):
"""Test execution with multiple required tools."""
task = AnalysisTask(
id="task_1",
name="Multi-tool Task",
description="Use multiple tools",
priority=3,
required_tools=['calculate_statistics', 'get_value_counts']
)
data_access = DataAccessLayer(sample_data)
result = _fallback_task_execution(task, sample_tools, data_access)
# Should execute first available tool
assert result is not None
def test_execution_time_tracking(sample_data, sample_tools):
"""Test that execution time is tracked."""
task = AnalysisTask(
id="task_1",
name="Test Task",
description="Test",
priority=3,
required_tools=['calculate_statistics']
)
data_access = DataAccessLayer(sample_data)
result = _fallback_task_execution(task, sample_tools, data_access)
assert result.execution_time >= 0
assert result.execution_time < 10 # Should be fast
def test_execution_with_empty_data():
"""Test execution with empty data."""
empty_data = pd.DataFrame()
task = AnalysisTask(
id="task_1",
name="Test Task",
description="Test",
priority=3,
required_tools=['calculate_statistics']
)
data_access = DataAccessLayer(empty_data)
tools = [CalculateStatisticsTool()]
result = _fallback_task_execution(task, tools, data_access)
# Should handle gracefully
assert result is not None

@@ -1,202 +0,0 @@
"""Property-based tests for task execution engine."""
import pytest
import pandas as pd
from hypothesis import given, strategies as st, settings
from src.engines.task_execution import (
execute_task,
call_tool,
extract_insights,
_fallback_task_execution
)
from src.models.analysis_plan import AnalysisTask
from src.data_access import DataAccessLayer
from src.tools.stats_tools import CalculateStatisticsTool
# Feature: true-ai-agent, Property 13: 任务执行完整性
@given(
task_name=st.text(min_size=5, max_size=50),
task_description=st.text(min_size=10, max_size=100)
)
@settings(max_examples=10, deadline=None)
def test_task_execution_completeness(task_name, task_description):
"""
Property 13: For any valid analysis plan and tool set, the task execution
engine should be able to execute all non-skipped tasks and generate an
analysis result (success or failure) for each task.
Validates: 场景1验收.3, FR-5.1
"""
# Create sample data
sample_data = pd.DataFrame({
'value': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'category': ['A', 'B', 'A', 'B', 'A', 'B', 'A', 'B', 'A', 'B']
})
# Create sample tools
sample_tools = [CalculateStatisticsTool()]
# Create task
task = AnalysisTask(
id="test_task",
name=task_name,
description=task_description,
priority=3,
required_tools=['calculate_statistics']
)
# Create data access
data_access = DataAccessLayer(sample_data)
# Execute task (using fallback to avoid API dependency)
result = _fallback_task_execution(task, sample_tools, data_access)
# Verify: Should return AnalysisResult
assert result is not None
assert result.task_id == task.id
assert result.task_name == task.name
# Verify: Should have success status
assert isinstance(result.success, bool)
# Verify: Should have execution time
assert result.execution_time >= 0
# Verify: If failed, should have error message
if not result.success:
assert result.error is not None
# Verify: Should have insights (even if empty)
assert isinstance(result.insights, list)
# Feature: true-ai-agent, Property 14: ReAct 循环终止
def test_react_loop_termination():
"""
Property 14: For any analysis task, the ReAct execution loop should
terminate within a finite number of steps (either complete the task
or reach maximum iterations), and should not loop infinitely.
Validates: FR-5.1
"""
sample_data = pd.DataFrame({
'value': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'category': ['A', 'B', 'A', 'B', 'A', 'B', 'A', 'B', 'A', 'B']
})
sample_tools = [CalculateStatisticsTool()]
task = AnalysisTask(
id="test_task",
name="Test Task",
description="Calculate statistics",
priority=3,
required_tools=['calculate_statistics']
)
data_access = DataAccessLayer(sample_data)
# Execute with limited iterations
result = _fallback_task_execution(task, sample_tools, data_access)
# Verify: Should complete (not hang)
assert result is not None
# Verify: Should have finite execution time
assert result.execution_time < 60, "Execution should complete within 60 seconds"
# Feature: true-ai-agent, Property 15: 异常识别
def test_anomaly_identification():
"""
Property 15: For any data containing obvious anomalies (e.g., a category
accounting for >80% of data, or values exceeding 3 standard deviations),
the task execution engine should be able to mark the anomaly in the
analysis result insights.
Validates: 场景4验收.1
"""
# Create data with anomaly (category A is 90%)
anomaly_data = pd.DataFrame({
'value': list(range(100)),
'category': ['A'] * 90 + ['B'] * 10
})
task = AnalysisTask(
id="test_task",
name="Anomaly Detection",
description="Detect anomalies in data",
priority=3,
required_tools=['calculate_statistics']
)
data_access = DataAccessLayer(anomaly_data)
tools = [CalculateStatisticsTool()]
result = _fallback_task_execution(task, tools, data_access)
# Verify: Should complete successfully
assert result.success or result.error is not None
# Verify: Should have insights
assert isinstance(result.insights, list)
# Test tool calling
def test_call_tool_success():
"""Test successful tool calling."""
sample_data = pd.DataFrame({
'value': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'category': ['A', 'B', 'A', 'B', 'A', 'B', 'A', 'B', 'A', 'B']
})
tool = CalculateStatisticsTool()
data_access = DataAccessLayer(sample_data)
result = call_tool(tool, data_access, column='value')
# Should return result dict
assert isinstance(result, dict)
assert 'success' in result
# Test insight extraction
def test_extract_insights_without_ai():
"""Test insight extraction without AI."""
history = [
{'type': 'thought', 'content': 'Analyzing data'},
{'type': 'action', 'tool': 'calculate_statistics'},
{'type': 'observation', 'result': {'data': {'mean': 5.5}}}
]
insights = extract_insights(history, client=None)
# Should return list of insights
assert isinstance(insights, list)
assert len(insights) > 0
# Test execution with empty tools
def test_execution_with_no_tools():
"""Test execution when no tools are available."""
sample_data = pd.DataFrame({
'value': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'category': ['A', 'B', 'A', 'B', 'A', 'B', 'A', 'B', 'A', 'B']
})
task = AnalysisTask(
id="test_task",
name="Test Task",
description="Test",
priority=3,
required_tools=['nonexistent_tool']
)
data_access = DataAccessLayer(sample_data)
result = _fallback_task_execution(task, [], data_access)
# Should fail gracefully
assert not result.success
assert result.error is not None

@@ -1,680 +0,0 @@
"""Unit tests for the tool system."""
import pytest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from src.tools.base import AnalysisTool, ToolRegistry
from src.tools.query_tools import (
GetColumnDistributionTool,
GetValueCountsTool,
GetTimeSeriesTool,
GetCorrelationTool
)
from src.tools.stats_tools import (
CalculateStatisticsTool,
PerformGroupbyTool,
DetectOutliersTool,
CalculateTrendTool
)
from src.models import DataProfile, ColumnInfo
class TestGetColumnDistributionTool:
"""Tests for the column distribution tool."""
def test_basic_functionality(self):
"""Test basic functionality."""
tool = GetColumnDistributionTool()
df = pd.DataFrame({
'status': ['open', 'closed', 'open', 'pending', 'closed', 'open']
})
result = tool.execute(df, column='status')
assert 'distribution' in result
assert result['column'] == 'status'
assert result['total_count'] == 6
assert result['unique_count'] == 3
assert len(result['distribution']) == 3
def test_top_n_limit(self):
"""Test the top_n parameter limit."""
tool = GetColumnDistributionTool()
df = pd.DataFrame({
'value': list(range(20))
})
result = tool.execute(df, column='value', top_n=5)
assert len(result['distribution']) == 5
def test_nonexistent_column(self):
"""Test a nonexistent column."""
tool = GetColumnDistributionTool()
df = pd.DataFrame({'col1': [1, 2, 3]})
result = tool.execute(df, column='nonexistent')
assert 'error' in result
class TestGetValueCountsTool:
"""Tests for the value counts tool."""
def test_basic_functionality(self):
"""Test basic functionality."""
tool = GetValueCountsTool()
df = pd.DataFrame({
'category': ['A', 'B', 'A', 'C', 'B', 'A']
})
result = tool.execute(df, column='category')
assert 'value_counts' in result
assert result['value_counts']['A'] == 3
assert result['value_counts']['B'] == 2
assert result['value_counts']['C'] == 1
def test_normalized_counts(self):
"""Test normalized counts."""
tool = GetValueCountsTool()
df = pd.DataFrame({
'category': ['A', 'A', 'B', 'B']
})
result = tool.execute(df, column='category', normalize=True)
assert result['normalized'] is True
assert abs(result['value_counts']['A'] - 0.5) < 0.01
assert abs(result['value_counts']['B'] - 0.5) < 0.01
class TestGetTimeSeriesTool:
"""Tests for the time series tool."""
def test_basic_functionality(self):
"""Test basic functionality."""
tool = GetTimeSeriesTool()
dates = pd.date_range('2020-01-01', periods=10, freq='D')
df = pd.DataFrame({
'date': dates,
'value': range(10)
})
result = tool.execute(df, time_column='date', value_column='value', aggregation='sum')
assert 'time_series' in result
assert result['time_column'] == 'date'
assert result['aggregation'] == 'sum'
assert len(result['time_series']) > 0
def test_count_aggregation(self):
"""Test count aggregation."""
tool = GetTimeSeriesTool()
dates = pd.date_range('2020-01-01', periods=5, freq='D')
df = pd.DataFrame({'date': dates})
result = tool.execute(df, time_column='date', aggregation='count')
assert 'time_series' in result
assert len(result['time_series']) > 0
def test_output_limit(self):
"""Test that output is limited to at most 100 rows."""
tool = GetTimeSeriesTool()
dates = pd.date_range('2020-01-01', periods=200, freq='D')
df = pd.DataFrame({'date': dates})
result = tool.execute(df, time_column='date')
assert len(result['time_series']) <= 100
assert result['total_points'] == 200
assert result['returned_points'] == 100
class TestGetCorrelationTool:
"""Tests for the correlation analysis tool."""
def test_basic_functionality(self):
"""Test basic functionality."""
tool = GetCorrelationTool()
df = pd.DataFrame({
'x': [1, 2, 3, 4, 5],
'y': [2, 4, 6, 8, 10],
'z': [1, 1, 1, 1, 1]
})
result = tool.execute(df)
assert 'correlation_matrix' in result
assert 'x' in result['correlation_matrix']
assert 'y' in result['correlation_matrix']
# x and y are perfectly positively correlated
assert abs(result['correlation_matrix']['x']['y'] - 1.0) < 0.01
def test_insufficient_numeric_columns(self):
"""Test the case of too few numeric columns."""
tool = GetCorrelationTool()
df = pd.DataFrame({
'x': [1, 2, 3],
'text': ['a', 'b', 'c']
})
result = tool.execute(df)
assert 'error' in result
class TestCalculateStatisticsTool:
"""Tests for the statistics calculation tool."""
def test_basic_functionality(self):
"""Test basic functionality."""
tool = CalculateStatisticsTool()
df = pd.DataFrame({
'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
})
result = tool.execute(df, column='values')
assert result['mean'] == 5.5
assert result['median'] == 5.5
assert result['min'] == 1
assert result['max'] == 10
assert result['count'] == 10
def test_non_numeric_column(self):
"""Test a non-numeric column."""
tool = CalculateStatisticsTool()
df = pd.DataFrame({
'text': ['a', 'b', 'c']
})
result = tool.execute(df, column='text')
assert 'error' in result
class TestPerformGroupbyTool:
"""Tests for the group-by aggregation tool."""
def test_basic_functionality(self):
"""Test basic functionality."""
tool = PerformGroupbyTool()
df = pd.DataFrame({
'category': ['A', 'B', 'A', 'B', 'A'],
'value': [10, 20, 30, 40, 50]
})
result = tool.execute(df, group_by='category', value_column='value', aggregation='sum')
assert 'groups' in result
assert len(result['groups']) == 2
# Find the sum for group A
group_a = next(g for g in result['groups'] if g['group'] == 'A')
assert group_a['value'] == 90 # 10 + 30 + 50
def test_count_aggregation(self):
"""Test count aggregation."""
tool = PerformGroupbyTool()
df = pd.DataFrame({
'category': ['A', 'B', 'A', 'B', 'A']
})
result = tool.execute(df, group_by='category')
assert len(result['groups']) == 2
group_a = next(g for g in result['groups'] if g['group'] == 'A')
assert group_a['value'] == 3
def test_output_limit(self):
"""Test that output is limited to at most 100 groups."""
tool = PerformGroupbyTool()
df = pd.DataFrame({
'category': [f'cat_{i}' for i in range(200)],
'value': range(200)
})
result = tool.execute(df, group_by='category', value_column='value', aggregation='sum')
assert len(result['groups']) <= 100
assert result['total_groups'] == 200
assert result['returned_groups'] == 100
class TestDetectOutliersTool:
"""Tests for the outlier detection tool."""
def test_iqr_method(self):
"""Test the IQR method."""
tool = DetectOutliersTool()
# Create data containing an obvious outlier
df = pd.DataFrame({
'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 100]
})
result = tool.execute(df, column='values', method='iqr')
assert result['outlier_count'] > 0
assert 100 in result['outlier_values']
def test_zscore_method(self):
"""Test the Z-score method."""
tool = DetectOutliersTool()
df = pd.DataFrame({
'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 100]
})
result = tool.execute(df, column='values', method='zscore', threshold=2)
assert result['outlier_count'] > 0
assert result['method'] == 'zscore'
class TestCalculateTrendTool:
"""Tests for the trend calculation tool."""
def test_increasing_trend(self):
"""Test an increasing trend."""
tool = CalculateTrendTool()
dates = pd.date_range('2020-01-01', periods=10, freq='D')
df = pd.DataFrame({
'date': dates,
'value': range(10)
})
result = tool.execute(df, time_column='date', value_column='value')
assert result['trend'] == 'increasing'
assert result['slope'] > 0
assert result['r_squared'] > 0.9 # perfect linear relationship
def test_decreasing_trend(self):
"""Test a decreasing trend."""
tool = CalculateTrendTool()
dates = pd.date_range('2020-01-01', periods=10, freq='D')
df = pd.DataFrame({
'date': dates,
'value': list(range(10, 0, -1))
})
result = tool.execute(df, time_column='date', value_column='value')
assert result['trend'] == 'decreasing'
assert result['slope'] < 0
class TestToolParameterValidation:
"""Tests for tool parameter validation."""
def test_missing_required_parameter(self):
"""Test a missing required parameter."""
tool = GetColumnDistributionTool()
df = pd.DataFrame({'col': [1, 2, 3]})
# Omit the required column parameter
result = tool.execute(df)
# Should return an error or None; check None first to avoid a TypeError
assert result is None or 'error' in result
def test_invalid_aggregation_method(self):
"""Test an invalid aggregation method."""
tool = PerformGroupbyTool()
df = pd.DataFrame({
'category': ['A', 'B'],
'value': [1, 2]
})
result = tool.execute(df, group_by='category', value_column='value', aggregation='invalid')
assert 'error' in result
class TestToolErrorHandling:
"""Tests for tool error handling."""
def test_empty_dataframe(self):
"""Test an empty DataFrame."""
tool = CalculateStatisticsTool()
df = pd.DataFrame()
result = tool.execute(df, column='nonexistent')
assert 'error' in result
def test_all_null_values(self):
"""Test a column consisting entirely of null values."""
tool = CalculateStatisticsTool()
df = pd.DataFrame({
'values': [None, None, None]
})
result = tool.execute(df, column='values')
# Should handle the all-null case
assert 'error' in result or result['count'] == 0
def test_invalid_date_column(self):
"""Test an invalid date column."""
tool = GetTimeSeriesTool()
df = pd.DataFrame({
'not_date': ['a', 'b', 'c']
})
result = tool.execute(df, time_column='not_date')
assert 'error' in result
class TestToolRegistry:
"""Tests for the tool registry."""
def test_register_and_retrieve(self):
"""Test registering and retrieving a tool."""
registry = ToolRegistry()
tool = GetColumnDistributionTool()
registry.register(tool)
retrieved = registry.get_tool(tool.name)
assert retrieved.name == tool.name
def test_unregister(self):
"""Test unregistering a tool."""
registry = ToolRegistry()
tool = GetColumnDistributionTool()
registry.register(tool)
registry.unregister(tool.name)
with pytest.raises(KeyError):
registry.get_tool(tool.name)
def test_list_tools(self):
"""Test listing all tools."""
registry = ToolRegistry()
tool1 = GetColumnDistributionTool()
tool2 = GetValueCountsTool()
registry.register(tool1)
registry.register(tool2)
tools = registry.list_tools()
assert len(tools) == 2
assert tool1.name in tools
assert tool2.name in tools
def test_get_applicable_tools(self):
"""Test retrieving applicable tools."""
registry = ToolRegistry()
# Register all tools
registry.register(GetColumnDistributionTool())
registry.register(CalculateStatisticsTool())
registry.register(GetTimeSeriesTool())
# Create a data profile with numeric and datetime columns
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=2,
columns=[
ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50),
ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100)
],
inferred_type='unknown'
)
applicable = registry.get_applicable_tools(profile)
# All tools should be applicable (GetColumnDistributionTool applies to all data)
assert len(applicable) > 0
class TestToolManager:
"""Tests for the tool manager."""
def test_select_tools_for_datetime_data(self):
"""Test tool selection for data with datetime fields."""
from src.tools.tool_manager import ToolManager
# Create a tool registry and register all tools
registry = ToolRegistry()
registry.register(GetTimeSeriesTool())
registry.register(CalculateTrendTool())
registry.register(GetColumnDistributionTool())
manager = ToolManager(registry)
# Create a data profile with a datetime field
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
tools = manager.select_tools(profile)
tool_names = [tool.name for tool in tools]
# Should include the time series tools
assert 'get_time_series' in tool_names
assert 'calculate_trend' in tool_names
def test_select_tools_for_numeric_data(self):
"""Test tool selection for data with numeric fields."""
from src.tools.tool_manager import ToolManager
registry = ToolRegistry()
registry.register(CalculateStatisticsTool())
registry.register(DetectOutliersTool())
registry.register(GetCorrelationTool())
manager = ToolManager(registry)
# Create a data profile with numeric fields
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=2,
columns=[
ColumnInfo(name='value1', dtype='numeric', missing_rate=0.0, unique_count=50),
ColumnInfo(name='value2', dtype='numeric', missing_rate=0.0, unique_count=50)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
tools = manager.select_tools(profile)
tool_names = [tool.name for tool in tools]
# Should include the statistics tools
assert 'calculate_statistics' in tool_names
assert 'detect_outliers' in tool_names
assert 'get_correlation' in tool_names
def test_select_tools_for_categorical_data(self):
"""Test tool selection for data with categorical fields."""
from src.tools.tool_manager import ToolManager
registry = ToolRegistry()
registry.register(GetColumnDistributionTool())
registry.register(GetValueCountsTool())
registry.register(PerformGroupbyTool())
manager = ToolManager(registry)
# Create a data profile with a categorical field
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='category', dtype='categorical', missing_rate=0.0, unique_count=5)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
tools = manager.select_tools(profile)
tool_names = [tool.name for tool in tools]
# Should include the categorical tools
assert 'get_column_distribution' in tool_names
assert 'get_value_counts' in tool_names
assert 'perform_groupby' in tool_names
def test_no_geo_tools_for_non_geo_data(self):
"""Test that geographic tools are not selected for non-geographic data."""
from src.tools.tool_manager import ToolManager
registry = ToolRegistry()
registry.register(GetColumnDistributionTool())
manager = ToolManager(registry)
# Create a data profile without geographic fields
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
tools = manager.select_tools(profile)
tool_names = [tool.name for tool in tools]
# Should not include geographic tools
assert 'create_map_visualization' not in tool_names
def test_identify_missing_tools(self):
"""Test identifying missing tools."""
from src.tools.tool_manager import ToolManager
# Create an empty tool registry
empty_registry = ToolRegistry()
manager = ToolManager(empty_registry)
# Create a data profile with a datetime field
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
# Try to select tools (this records any missing tools)
tools = manager.select_tools(profile)
# Retrieve the missing tools
missing = manager.get_missing_tools()
# Should identify the missing time series tools
assert len(missing) > 0
assert any(tool in missing for tool in ['get_time_series', 'calculate_trend'])
def test_clear_missing_tools(self):
"""Test clearing the missing-tools list."""
from src.tools.tool_manager import ToolManager
empty_registry = ToolRegistry()
manager = ToolManager(empty_registry)
# Create a data profile and select tools (records missing tools)
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
manager.select_tools(profile)
assert len(manager.get_missing_tools()) > 0
# Clear the missing-tools list
manager.clear_missing_tools()
assert len(manager.get_missing_tools()) == 0
def test_get_tool_descriptions(self):
"""Test retrieving tool descriptions."""
from src.tools.tool_manager import ToolManager
registry = ToolRegistry()
tool1 = GetColumnDistributionTool()
tool2 = CalculateStatisticsTool()
registry.register(tool1)
registry.register(tool2)
manager = ToolManager(registry)
tools = [tool1, tool2]
descriptions = manager.get_tool_descriptions(tools)
assert len(descriptions) == 2
assert all('name' in desc for desc in descriptions)
assert all('description' in desc for desc in descriptions)
assert all('parameters' in desc for desc in descriptions)
def test_tool_deduplication(self):
"""Test tool deduplication."""
from src.tools.tool_manager import ToolManager
registry = ToolRegistry()
# Register one tool that may be selected by multiple categories
tool = GetColumnDistributionTool()
registry.register(tool)
manager = ToolManager(registry)
# Create a data profile with multiple field types
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=2,
columns=[
ColumnInfo(name='category', dtype='categorical', missing_rate=0.0, unique_count=5),
ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
tools = manager.select_tools(profile)
tool_names = [tool.name for tool in tools]
# Tool names should be unique (no duplicates)
assert len(tool_names) == len(set(tool_names))

@@ -1,620 +0,0 @@
"""Property-based tests for the tool system."""
import pytest
import pandas as pd
import numpy as np
from hypothesis import given, strategies as st, settings, assume
from typing import Dict, Any
from src.tools.base import AnalysisTool, ToolRegistry
from src.tools.query_tools import (
GetColumnDistributionTool,
GetValueCountsTool,
GetTimeSeriesTool,
GetCorrelationTool
)
from src.tools.stats_tools import (
CalculateStatisticsTool,
PerformGroupbyTool,
DetectOutliersTool,
CalculateTrendTool
)
from src.models import DataProfile, ColumnInfo
# Hypothesis strategies for generating test data
@st.composite
def column_info_strategy(draw):
"""Generate a random ColumnInfo instance."""
dtype = draw(st.sampled_from(['numeric', 'categorical', 'datetime', 'text']))
return ColumnInfo(
name=draw(st.text(min_size=1, max_size=20, alphabet=st.characters(whitelist_categories=('Lu', 'Ll')))),
dtype=dtype,
missing_rate=draw(st.floats(min_value=0.0, max_value=1.0)),
unique_count=draw(st.integers(min_value=1, max_value=1000)),
sample_values=draw(st.lists(st.integers(), min_size=1, max_size=5)),
statistics={'mean': draw(st.floats(allow_nan=False, allow_infinity=False))} if dtype == 'numeric' else {}
)
@st.composite
def data_profile_strategy(draw):
"""Generate a random DataProfile instance."""
columns = draw(st.lists(column_info_strategy(), min_size=1, max_size=10))
return DataProfile(
file_path=draw(st.text(min_size=1, max_size=50)),
row_count=draw(st.integers(min_value=1, max_value=10000)),
column_count=len(columns),
columns=columns,
inferred_type=draw(st.sampled_from(['ticket', 'sales', 'user', 'unknown'])),
key_fields={},
quality_score=draw(st.floats(min_value=0.0, max_value=100.0)),
summary=draw(st.text(max_size=100))
)
@st.composite
def dataframe_strategy(draw, min_rows=10, max_rows=100, min_cols=2, max_cols=10):
"""生成随机的 DataFrame 实例。"""
n_rows = draw(st.integers(min_value=min_rows, max_value=max_rows))
n_cols = draw(st.integers(min_value=min_cols, max_value=max_cols))
data = {}
for i in range(n_cols):
col_type = draw(st.sampled_from(['int', 'float', 'str']))
col_name = f'col_{i}'
if col_type == 'int':
data[col_name] = draw(st.lists(
st.integers(min_value=-1000, max_value=1000),
min_size=n_rows,
max_size=n_rows
))
elif col_type == 'float':
data[col_name] = draw(st.lists(
st.floats(min_value=-1000.0, max_value=1000.0, allow_nan=False, allow_infinity=False),
min_size=n_rows,
max_size=n_rows
))
else: # str
data[col_name] = draw(st.lists(
st.text(min_size=1, max_size=10, alphabet=st.characters(whitelist_categories=('Lu', 'Ll'))),
min_size=n_rows,
max_size=n_rows
))
return pd.DataFrame(data)
# All tool classes under test
ALL_TOOLS = [
GetColumnDistributionTool,
GetValueCountsTool,
GetTimeSeriesTool,
GetCorrelationTool,
CalculateStatisticsTool,
PerformGroupbyTool,
DetectOutliersTool,
CalculateTrendTool
]
# Feature: true-ai-agent, Property 10: tool interface consistency
@given(tool_class=st.sampled_from(ALL_TOOLS))
@settings(max_examples=20)
def test_tool_interface_consistency(tool_class):
"""
属性 10对于任何工具它应该实现标准接口name, description, parameters,
execute, is_applicable并且 execute 方法应该接受 DataFrame 和参数,
返回字典格式的聚合结果。
验证需求FR-4.1
"""
# Create a tool instance
tool = tool_class()
# Verify: the tool should be a subclass of AnalysisTool
assert isinstance(tool, AnalysisTool), f"{tool_class.__name__} is not a subclass of AnalysisTool"
# Verify: the tool should have a name attribute that is a non-empty string
assert hasattr(tool, 'name'), f"{tool_class.__name__} is missing the name attribute"
assert isinstance(tool.name, str), f"{tool_class.__name__}.name is not a string"
assert len(tool.name) > 0, f"{tool_class.__name__}.name is an empty string"
# Verify: the tool should have a description attribute that is a non-empty string
assert hasattr(tool, 'description'), f"{tool_class.__name__} is missing the description attribute"
assert isinstance(tool.description, str), f"{tool_class.__name__}.description is not a string"
assert len(tool.description) > 0, f"{tool_class.__name__}.description is an empty string"
# Verify: the tool should have a parameters attribute that is a dict
assert hasattr(tool, 'parameters'), f"{tool_class.__name__} is missing the parameters attribute"
assert isinstance(tool.parameters, dict), f"{tool_class.__name__}.parameters is not a dict"
# Verify: parameters should follow the JSON Schema format
params = tool.parameters
assert 'type' in params, f"{tool_class.__name__}.parameters is missing the 'type' field"
assert params['type'] == 'object', f"{tool_class.__name__}.parameters.type is not 'object'"
# Verify: the tool should have an execute method
assert hasattr(tool, 'execute'), f"{tool_class.__name__} is missing the execute method"
assert callable(tool.execute), f"{tool_class.__name__}.execute is not callable"
# Verify: the tool should have an is_applicable method
assert hasattr(tool, 'is_applicable'), f"{tool_class.__name__} is missing the is_applicable method"
assert callable(tool.is_applicable), f"{tool_class.__name__}.is_applicable is not callable"
# Verify: execute should accept a DataFrame and keyword arguments.
# Create a simple test DataFrame.
test_df = pd.DataFrame({
'col_0': [1, 2, 3, 4, 5],
'col_1': ['a', 'b', 'c', 'd', 'e']
})
# Try calling execute (it may fail, but not because of its signature)
try:
# Call with no extra parameters (failure due to missing required parameters is expected)
tool.execute(test_df)
except (KeyError, ValueError, TypeError):
# These exceptions are acceptable (parameter validation failures)
pass
# Verify: execute should return a dict.
# Provide valid parameters, chosen per tool type, to test the return type.
if tool.name == 'get_column_distribution':
result = tool.execute(test_df, column='col_0')
elif tool.name == 'get_value_counts':
result = tool.execute(test_df, column='col_0')
elif tool.name == 'calculate_statistics':
result = tool.execute(test_df, column='col_0')
elif tool.name == 'perform_groupby':
result = tool.execute(test_df, group_by='col_1')
elif tool.name == 'detect_outliers':
result = tool.execute(test_df, column='col_0')
elif tool.name == 'get_correlation':
test_df_numeric = pd.DataFrame({
'col_0': [1, 2, 3, 4, 5],
'col_1': [2, 4, 6, 8, 10]
})
result = tool.execute(test_df_numeric)
elif tool.name == 'get_time_series':
test_df_time = pd.DataFrame({
'time': pd.date_range('2020-01-01', periods=5),
'value': [1, 2, 3, 4, 5]
})
result = tool.execute(test_df_time, time_column='time')
elif tool.name == 'calculate_trend':
test_df_trend = pd.DataFrame({
'time': pd.date_range('2020-01-01', periods=5),
'value': [1, 2, 3, 4, 5]
})
result = tool.execute(test_df_trend, time_column='time', value_column='value')
else:
# Unknown tool; skip return-type verification
return
# Verify: the return value should be a dict
assert isinstance(result, dict), f"{tool_class.__name__}.execute returned {type(result)} instead of a dict"
# Feature: true-ai-agent, Property 19: tool output filtering
@given(
tool_class=st.sampled_from(ALL_TOOLS),
df=dataframe_strategy(min_rows=200, max_rows=500)
)
@settings(max_examples=20, deadline=None)
def test_tool_output_filtering(tool_class, df):
"""
属性 19对于任何工具的执行结果返回的数据应该是聚合后的如统计值、
分组计数、图表数据单次返回的数据行数不应超过100行并且不应包含
完整的原始数据表。
验证需求约束条件5.3
"""
# Create a tool instance
tool = tool_class()
# Ensure the DataFrame has enough rows to exercise output filtering
assume(len(df) >= 200)
# Prepare appropriate parameters and data for each tool type
result = None
try:
if tool.name == 'get_column_distribution':
# Use the first column
col_name = df.columns[0]
result = tool.execute(df, column=col_name, top_n=10)
elif tool.name == 'get_value_counts':
col_name = df.columns[0]
result = tool.execute(df, column=col_name)
elif tool.name == 'calculate_statistics':
# Find numeric columns
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 0:
result = tool.execute(df, column=numeric_cols[0])
elif tool.name == 'perform_groupby':
# Use the first column as the group-by column
result = tool.execute(df, group_by=df.columns[0])
elif tool.name == 'detect_outliers':
# Find numeric columns
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 0:
result = tool.execute(df, column=numeric_cols[0])
elif tool.name == 'get_correlation':
# Requires at least two numeric columns
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) >= 2:
result = tool.execute(df)
elif tool.name == 'get_time_series':
# Create a DataFrame with a time column
df_with_time = df.copy()
df_with_time['time_col'] = pd.date_range('2020-01-01', periods=len(df))
result = tool.execute(df_with_time, time_column='time_col')
elif tool.name == 'calculate_trend':
# Create a DataFrame with time and numeric columns
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 0:
df_with_time = df.copy()
df_with_time['time_col'] = pd.date_range('2020-01-01', periods=len(df))
result = tool.execute(df_with_time, time_column='time_col', value_column=numeric_cols[0])
except (KeyError, ValueError, TypeError):
# The tool may fail because the data is not applicable, which is acceptable;
# skip this test case
assume(False)
# If there is no result (the tool was not applicable), skip verification
if result is None:
assume(False)
# If the result contains an error, skip verification (the tool correctly rejected inapplicable data)
if 'error' in result:
assume(False)
# Verify: the result should be a dict
assert isinstance(result, dict), f"Tool {tool.name} did not return a dict"
# Verify: the result should not contain the full raw data.
# Check every value in the result.
def count_data_rows(obj, max_depth=5):
"""递归计数结果中的数据行数"""
if max_depth <= 0:
return 0
if isinstance(obj, list):
# Lists: use their length
return len(obj)
elif isinstance(obj, dict):
# Dicts: recurse into all values
max_count = 0
for value in obj.values():
count = count_data_rows(value, max_depth - 1)
max_count = max(max_count, count)
return max_count
else:
return 0
# Compute the maximum number of data rows in the result
max_rows_in_result = count_data_rows(result)
# Verify: a single call should return at most 100 rows of data
assert max_rows_in_result <= 100, (
f"Tool {tool.name} returned {max_rows_in_result} rows of data, "
f"exceeding the 100-row limit. The raw data has {len(df)} rows."
)
# Verify: the result should be aggregated data rather than raw data.
# The result should be clearly smaller than the original data:
# an aggregate's row count should be far below the raw row count.
if max_rows_in_result > 0:
compression_ratio = max_rows_in_result / len(df)
# The aggregate should compress to at most 60% of the raw data
# (for 200+ rows, aggregates should be significantly smaller;
# note that time-series tools may return up to 100 data points,
# so for 200 rows the ratio is 50%).
assert compression_ratio <= 0.6, (
f"Tool {tool.name} has an output compression ratio of {compression_ratio:.2%}, which is too high; "
f"it may be returning too much raw data instead of aggregated results"
)
# Verify: the result should contain aggregate information rather than raw row data.
# Check whether the result includes typical aggregation fields.
aggregation_indicators = [
'count', 'sum', 'mean', 'median', 'std', 'min', 'max',
'distribution', 'groups', 'correlation', 'statistics',
'time_series', 'aggregation', 'value_counts'
]
has_aggregation = any(
indicator in str(result).lower()
for indicator in aggregation_indicators
)
# If the result has data, it should include aggregation indicators
if max_rows_in_result > 0:
assert has_aggregation, (
f"Tool {tool.name}'s result does not appear to contain aggregate information; "
f"it may be returning raw data instead of aggregated results"
)
# Feature: true-ai-agent, Property 9: tool selection adaptability
@given(data_profile=data_profile_strategy())
@settings(max_examples=20)
def test_tool_selection_adaptability(data_profile):
"""
属性 9对于任何数据画像工具管理器选择的工具集应该与数据特征匹配
包含时间字段时启用时间序列工具,包含分类字段时启用分布分析工具,
包含数值字段时启用统计工具,不包含地理字段时不启用地理工具。
验证需求:工具动态性验收.1, 工具动态性验收.2, FR-4.2
"""
from src.tools.tool_manager import ToolManager
# Create a tool manager and register all tools
registry = ToolRegistry()
for tool_class in ALL_TOOLS:
registry.register(tool_class())
manager = ToolManager(registry)
# Select tools
selected_tools = manager.select_tools(data_profile)
selected_tool_names = [tool.name for tool in selected_tools]
# Verify: if a datetime field is present, time-series tools should be enabled
has_datetime = any(col.dtype == 'datetime' for col in data_profile.columns)
time_series_tools = ['get_time_series', 'calculate_trend', 'create_line_chart']
if has_datetime:
# At least one time-series tool should be selected
has_time_tool = any(tool_name in selected_tool_names for tool_name in time_series_tools)
assert has_time_tool, (
f"The data contains a datetime field, but no time-series tool was selected. "
f"Selected tools: {selected_tool_names}"
)
# Verify: if a categorical field is present, distribution-analysis tools should be enabled
has_categorical = any(col.dtype == 'categorical' for col in data_profile.columns)
categorical_tools = ['get_column_distribution', 'get_value_counts', 'perform_groupby',
'create_bar_chart', 'create_pie_chart']
if has_categorical:
# At least one categorical tool should be selected
has_cat_tool = any(tool_name in selected_tool_names for tool_name in categorical_tools)
assert has_cat_tool, (
f"The data contains a categorical field, but no categorical analysis tool was selected. "
f"Selected tools: {selected_tool_names}"
)
# Verify: if a numeric field is present, statistics tools should be enabled
has_numeric = any(col.dtype == 'numeric' for col in data_profile.columns)
numeric_tools = ['calculate_statistics', 'detect_outliers', 'get_correlation', 'create_heatmap']
if has_numeric:
# At least one numeric tool should be selected
has_num_tool = any(tool_name in selected_tool_names for tool_name in numeric_tools)
assert has_num_tool, (
f"The data contains a numeric field, but no statistical analysis tool was selected. "
f"Selected tools: {selected_tool_names}"
)
# Verify: if no geographic field is present, geographic tools should not be enabled
geo_keywords = ['lat', 'lon', 'latitude', 'longitude', 'location', 'address', 'city', 'country']
has_geo = any(
any(keyword in col.name.lower() for keyword in geo_keywords)
for col in data_profile.columns
)
geo_tools = ['create_map_visualization']
if not has_geo:
# No geographic tool should be selected
has_geo_tool = any(tool_name in selected_tool_names for tool_name in geo_tools)
assert not has_geo_tool, (
f"The data contains no geographic field, but a geographic tool was selected. "
f"Selected tools: {selected_tool_names}"
)
# Feature: true-ai-agent, Property 11: tool applicability judgment
@given(
tool_class=st.sampled_from(ALL_TOOLS),
data_profile=data_profile_strategy()
)
@settings(max_examples=20)
def test_tool_applicability_judgment(tool_class, data_profile):
"""
属性 11对于任何工具和数据画像工具的 is_applicable 方法应该正确判断
该工具是否适用于当前数据(例如时间序列工具只适用于包含时间字段的数据)。
验证需求FR-4.3
"""
# Create a tool instance
tool = tool_class()
# Call the is_applicable method
is_applicable = tool.is_applicable(data_profile)
# Verify: the return value should be a boolean
assert isinstance(is_applicable, bool), (
f"Tool {tool.name}'s is_applicable method returned a non-boolean: {type(is_applicable)}"
)
# Verify: the applicability judgment should match the data's characteristics.
# Check the applicability logic by tool type.
if tool.name in ['get_time_series', 'calculate_trend']:
# Time-series tools should apply only to data containing a datetime field
has_datetime = any(col.dtype == 'datetime' for col in data_profile.columns)
# calculate_trend also requires a numeric column
if tool.name == 'calculate_trend':
has_numeric = any(col.dtype == 'numeric' for col in data_profile.columns)
if has_datetime and has_numeric:
# With both datetime and numeric fields present, the tool should apply
assert is_applicable, (
f"Tool {tool.name} should apply to data with datetime and numeric fields, "
f"but is_applicable returned False"
)
else:
# get_time_series only requires a datetime field
if has_datetime:
# With a datetime field present, the tool should apply
assert is_applicable, (
f"Tool {tool.name} should apply to data with a datetime field, "
f"but is_applicable returned False"
)
elif tool.name in ['calculate_statistics', 'detect_outliers']:
# Statistics tools should apply only to data containing numeric fields
has_numeric = any(col.dtype == 'numeric' for col in data_profile.columns)
if has_numeric:
# With a numeric field present, the tool should apply
assert is_applicable, (
f"Tool {tool.name} should apply to data with numeric fields, "
f"but is_applicable returned False"
)
elif tool.name == 'get_correlation':
# The correlation tool requires at least two numeric fields
numeric_cols = [col for col in data_profile.columns if col.dtype == 'numeric']
has_enough_numeric = len(numeric_cols) >= 2
if has_enough_numeric:
# With enough numeric fields, the tool should apply
assert is_applicable, (
f"Tool {tool.name} should apply to data with at least two numeric fields, "
f"but is_applicable returned False"
)
else:
# With too few numeric fields, the tool should not apply
assert not is_applicable, (
f"Tool {tool.name} should not apply to data with fewer than two numeric fields, "
f"but is_applicable returned True"
)
elif tool.name == 'create_heatmap':
# The heatmap tool requires at least two numeric fields
numeric_cols = [col for col in data_profile.columns if col.dtype == 'numeric']
has_enough_numeric = len(numeric_cols) >= 2
if has_enough_numeric:
# With enough numeric fields, the tool should apply
assert is_applicable, (
f"Tool {tool.name} should apply to data with at least two numeric fields, "
f"but is_applicable returned False"
)
else:
# With too few numeric fields, the tool should not apply
assert not is_applicable, (
f"Tool {tool.name} should not apply to data with fewer than two numeric fields, "
f"but is_applicable returned True"
)
# Feature: true-ai-agent, Property 12: tool requirement identification
@given(data_profile=data_profile_strategy())
@settings(max_examples=20)
def test_tool_requirement_identification(data_profile):
"""
属性 12对于任何分析任务和可用工具集如果任务需要的工具不在可用工具集中
工具管理器应该能够识别缺失的工具并记录需求。
验证需求:工具动态性验收.3, FR-4.2
"""
from src.tools.tool_manager import ToolManager
# Create an empty tool registry (to simulate missing tools)
empty_registry = ToolRegistry()
manager = ToolManager(empty_registry)
# Clear the missing-tools list
manager.clear_missing_tools()
# Attempt to select tools
selected_tools = manager.select_tools(data_profile)
# Get the list of missing tools
missing_tools = manager.get_missing_tools()
# Verify: given the data's characteristics, the corresponding missing tools should be identified
has_datetime = any(col.dtype == 'datetime' for col in data_profile.columns)
has_categorical = any(col.dtype == 'categorical' for col in data_profile.columns)
has_numeric = any(col.dtype == 'numeric' for col in data_profile.columns)
# With a datetime field, missing time-series tools should be identified
if has_datetime:
time_tools = ['get_time_series', 'calculate_trend', 'create_line_chart']
has_missing_time_tool = any(tool in missing_tools for tool in time_tools)
assert has_missing_time_tool, (
f"The data contains a datetime field, but no missing time-series tool was identified. "
f"Missing tools: {missing_tools}"
)
# With a categorical field, missing categorical tools should be identified
if has_categorical:
cat_tools = ['get_column_distribution', 'get_value_counts', 'perform_groupby',
'create_bar_chart', 'create_pie_chart']
has_missing_cat_tool = any(tool in missing_tools for tool in cat_tools)
assert has_missing_cat_tool, (
f"The data contains a categorical field, but no missing categorical analysis tool was identified. "
f"Missing tools: {missing_tools}"
)
# With a numeric field, missing statistics tools should be identified
if has_numeric:
num_tools = ['calculate_statistics', 'detect_outliers', 'get_correlation', 'create_heatmap']
has_missing_num_tool = any(tool in missing_tools for tool in num_tools)
assert has_missing_num_tool, (
f"The data contains a numeric field, but no missing statistical analysis tool was identified. "
f"Missing tools: {missing_tools}"
)
# Additional test: verify that all tools correctly implement the interface
def test_all_tools_implement_interface():
"""Verify that every tool class correctly implements the AnalysisTool interface."""
for tool_class in ALL_TOOLS:
tool = tool_class()
# Check that the tool is an instance of AnalysisTool
assert isinstance(tool, AnalysisTool)
# Check that all required attributes and methods exist
assert hasattr(tool, 'name')
assert hasattr(tool, 'description')
assert hasattr(tool, 'parameters')
assert hasattr(tool, 'execute')
assert hasattr(tool, 'is_applicable')
# Check that the methods are callable
assert callable(tool.execute)
assert callable(tool.is_applicable)
# Additional test: verify tool registry functionality
def test_tool_registry_with_all_tools():
"""Test that ToolRegistry works correctly with all tools."""
registry = ToolRegistry()
# Register all tools
for tool_class in ALL_TOOLS:
tool = tool_class()
registry.register(tool)
# Verify that all tools are registered
registered_tools = registry.list_tools()
assert len(registered_tools) == len(ALL_TOOLS)
# Verify that each tool can be retrieved
for tool_class in ALL_TOOLS:
tool = tool_class()
retrieved_tool = registry.get_tool(tool.name)
assert retrieved_tool.name == tool.name
assert isinstance(retrieved_tool, AnalysisTool)

View File

@@ -1,357 +0,0 @@
"""可视化工具的单元测试。"""
import pytest
import pandas as pd
import numpy as np
import os
from pathlib import Path
import tempfile
import shutil
from src.tools.viz_tools import (
CreateBarChartTool,
CreateLineChartTool,
CreatePieChartTool,
CreateHeatmapTool
)
from src.models import DataProfile, ColumnInfo
@pytest.fixture
def temp_output_dir():
"""创建临时输出目录。"""
temp_dir = tempfile.mkdtemp()
yield temp_dir
# Clean up
shutil.rmtree(temp_dir, ignore_errors=True)
class TestCreateBarChartTool:
"""测试柱状图工具。"""
def test_basic_functionality(self, temp_output_dir):
"""测试基本功能。"""
tool = CreateBarChartTool()
df = pd.DataFrame({
'category': ['A', 'B', 'C', 'A', 'B', 'A'],
'value': [10, 20, 30, 15, 25, 20]
})
output_path = os.path.join(temp_output_dir, 'bar_chart.png')
result = tool.execute(df, x_column='category', output_path=output_path)
assert result['success'] is True
assert os.path.exists(output_path)
assert result['chart_type'] == 'bar'
assert result['x_column'] == 'category'
def test_with_y_column(self, temp_output_dir):
"""测试指定Y列。"""
tool = CreateBarChartTool()
df = pd.DataFrame({
'category': ['A', 'B', 'C'],
'value': [100, 200, 300]
})
output_path = os.path.join(temp_output_dir, 'bar_chart_y.png')
result = tool.execute(
df,
x_column='category',
y_column='value',
output_path=output_path
)
assert result['success'] is True
assert os.path.exists(output_path)
assert result['y_column'] == 'value'
def test_top_n_limit(self, temp_output_dir):
"""测试 top_n 限制。"""
tool = CreateBarChartTool()
df = pd.DataFrame({
'category': [f'cat_{i}' for i in range(50)],
'value': range(50)
})
output_path = os.path.join(temp_output_dir, 'bar_chart_top.png')
result = tool.execute(
df,
x_column='category',
y_column='value',
top_n=10,
output_path=output_path
)
assert result['success'] is True
assert result['data_points'] == 10
def test_nonexistent_column(self):
"""测试不存在的列。"""
tool = CreateBarChartTool()
df = pd.DataFrame({'col1': [1, 2, 3]})
result = tool.execute(df, x_column='nonexistent')
assert 'error' in result
class TestCreateLineChartTool:
"""测试折线图工具。"""
def test_basic_functionality(self, temp_output_dir):
"""测试基本功能。"""
tool = CreateLineChartTool()
df = pd.DataFrame({
'x': range(10),
'y': [i * 2 for i in range(10)]
})
output_path = os.path.join(temp_output_dir, 'line_chart.png')
result = tool.execute(
df,
x_column='x',
y_column='y',
output_path=output_path
)
assert result['success'] is True
assert os.path.exists(output_path)
assert result['chart_type'] == 'line'
def test_with_datetime(self, temp_output_dir):
"""测试时间序列数据。"""
tool = CreateLineChartTool()
dates = pd.date_range('2020-01-01', periods=20, freq='D')
df = pd.DataFrame({
'date': dates,
'value': range(20)
})
output_path = os.path.join(temp_output_dir, 'line_chart_time.png')
result = tool.execute(
df,
x_column='date',
y_column='value',
output_path=output_path
)
assert result['success'] is True
assert os.path.exists(output_path)
def test_large_dataset_sampling(self, temp_output_dir):
"""测试大数据集采样。"""
tool = CreateLineChartTool()
df = pd.DataFrame({
'x': range(2000),
'y': range(2000)
})
output_path = os.path.join(temp_output_dir, 'line_chart_large.png')
result = tool.execute(
df,
x_column='x',
y_column='y',
output_path=output_path
)
assert result['success'] is True
# Should be sampled down to about 1000 points
assert result['data_points'] <= 1000
class TestCreatePieChartTool:
"""测试饼图工具。"""
def test_basic_functionality(self, temp_output_dir):
"""测试基本功能。"""
tool = CreatePieChartTool()
df = pd.DataFrame({
'category': ['A', 'B', 'C', 'A', 'B', 'A']
})
output_path = os.path.join(temp_output_dir, 'pie_chart.png')
result = tool.execute(
df,
column='category',
output_path=output_path
)
assert result['success'] is True
assert os.path.exists(output_path)
assert result['chart_type'] == 'pie'
assert result['categories'] == 3
def test_top_n_with_others(self, temp_output_dir):
"""测试 top_n 并归类其他。"""
tool = CreatePieChartTool()
df = pd.DataFrame({
'category': [f'cat_{i}' for i in range(20)] * 5
})
output_path = os.path.join(temp_output_dir, 'pie_chart_top.png')
result = tool.execute(
df,
column='category',
top_n=5,
output_path=output_path
)
assert result['success'] is True
# 5 categories + 1 "Others"
assert result['categories'] == 6
class TestCreateHeatmapTool:
"""测试热力图工具。"""
def test_basic_functionality(self, temp_output_dir):
"""测试基本功能。"""
tool = CreateHeatmapTool()
df = pd.DataFrame({
'x': range(10),
'y': [i * 2 for i in range(10)],
'z': [i * 3 for i in range(10)]
})
output_path = os.path.join(temp_output_dir, 'heatmap.png')
result = tool.execute(df, output_path=output_path)
assert result['success'] is True
assert os.path.exists(output_path)
assert result['chart_type'] == 'heatmap'
assert len(result['columns']) == 3
def test_with_specific_columns(self, temp_output_dir):
"""测试指定列。"""
tool = CreateHeatmapTool()
df = pd.DataFrame({
'a': range(10),
'b': range(10, 20),
'c': range(20, 30),
'd': range(30, 40)
})
output_path = os.path.join(temp_output_dir, 'heatmap_cols.png')
result = tool.execute(
df,
columns=['a', 'b', 'c'],
output_path=output_path
)
assert result['success'] is True
assert len(result['columns']) == 3
assert 'd' not in result['columns']
def test_insufficient_columns(self):
"""测试列数不足。"""
tool = CreateHeatmapTool()
df = pd.DataFrame({'x': range(10)})
result = tool.execute(df)
assert 'error' in result
class TestVisualizationToolsApplicability:
"""测试可视化工具的适用性判断。"""
def test_bar_chart_applicability(self):
"""测试柱状图适用性。"""
tool = CreateBarChartTool()
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='cat', dtype='categorical', missing_rate=0.0, unique_count=5)
],
inferred_type='unknown'
)
assert tool.is_applicable(profile) is True
def test_line_chart_applicability(self):
"""测试折线图适用性。"""
tool = CreateLineChartTool()
# With a numeric column
profile_numeric = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50)
],
inferred_type='unknown'
)
assert tool.is_applicable(profile_numeric) is True
# Without a numeric column
profile_text = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='text', dtype='text', missing_rate=0.0, unique_count=50)
],
inferred_type='unknown'
)
assert tool.is_applicable(profile_text) is False
def test_heatmap_applicability(self):
"""测试热力图适用性。"""
tool = CreateHeatmapTool()
# With at least two numeric columns
profile_sufficient = DataProfile(
file_path='test.csv',
row_count=100,
column_count=2,
columns=[
ColumnInfo(name='x', dtype='numeric', missing_rate=0.0, unique_count=50),
ColumnInfo(name='y', dtype='numeric', missing_rate=0.0, unique_count=50)
],
inferred_type='unknown'
)
assert tool.is_applicable(profile_sufficient) is True
# With only one numeric column
profile_insufficient = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='x', dtype='numeric', missing_rate=0.0, unique_count=50)
],
inferred_type='unknown'
)
assert tool.is_applicable(profile_insufficient) is False
class TestVisualizationErrorHandling:
"""测试可视化工具的错误处理。"""
def test_invalid_output_path(self):
"""测试无效的输出路径。"""
tool = CreateBarChartTool()
df = pd.DataFrame({'cat': ['A', 'B', 'C']})
# Use an invalid path (read-only directory, etc.)
# Note: this test may not fail on some systems
result = tool.execute(
df,
x_column='cat',
output_path='/invalid/path/chart.png'
)
# Should either return an error or successfully create the directory
assert 'error' in result or result['success'] is True
def test_empty_dataframe(self):
"""测试空 DataFrame。"""
tool = CreateBarChartTool()
df = pd.DataFrame()
result = tool.execute(df, x_column='nonexistent')
assert 'error' in result