Add web session analysis platform with follow-up topics

2026-03-09 22:23:00 +08:00
commit 17ce711e49
30 changed files with 10681 additions and 0 deletions

9
.gitignore vendored Normal file

@@ -0,0 +1,9 @@
__pycache__/
*.pyc
.DS_Store
.env
.env copy
outputs/
runtime/
*.log
log.txt

21
LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Data Analysis Agent Team
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

357
README.md Normal file

@@ -0,0 +1,357 @@
# Data Analysis Agent (数据分析智能体)
🤖 **An LLM-powered intelligent data analysis agent**
[![Python Version](https://img.shields.io/badge/python-3.8%2B-blue.svg)](https://python.org)
[![License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE)
[![OpenAI](https://img.shields.io/badge/LLM-OpenAI%20Compatible-orange.svg)](https://openai.com)
## 📋 Project Overview
![alt text](assets/images/40d04b1dc21848cf9eeac4b50551f2a1.png)
![alt text](assets/images/d24d6dd97279a27fd8c9d652bac1fdb2.png)
The Data Analysis Agent is a powerful Python tool that combines the reasoning ability of large language models (LLMs) with the computational power of Python's data analysis libraries. It can:
- 🎯 **Analyze from natural language**: take a user's plain-language request and automatically generate professional analysis code
- 📊 **Visualize intelligently**: produce high-quality charts with Chinese-text support, written to a dedicated output directory
- 🔄 **Iterate and refine**: adjust the analysis strategy automatically based on execution results, continuously improving quality
- 📝 **Generate reports**: produce professional reports containing charts and conclusions (Markdown + Word)
- 🛡️ **Execute safely**: run code in a restricted environment that supports the common data analysis libraries
## 🏗️ Project Layout
```
data_analysis_agent/
├── 📁 config/                       # Configuration management
│   ├── __init__.py
│   └── llm_config.py                # LLM settings (API key, model, etc.)
├── 📁 utils/                        # Core utility modules
│   ├── code_executor.py             # Sandboxed code executor
│   ├── llm_helper.py                # LLM call helper
│   ├── fallback_openai_client.py    # OpenAI client with failover support
│   ├── extract_code.py              # Code extraction helpers
│   ├── format_execution_result.py   # Execution-result formatting
│   └── create_session_dir.py        # Session directory management
├── 📄 data_analysis_agent.py        # Main agent class
├── 📄 prompts.py                    # System prompt templates
├── 📄 main.py                       # Usage example
├── 📄 requirements.txt              # Project dependencies
├── 📄 .env                          # Environment variables
└── 📁 outputs/                      # Analysis output directory
    └── session_[timestamp]/         # One directory per analysis session
        ├── *.png                    # Generated charts
        ├── 最终分析报告.md           # Markdown report
        └── 最终分析报告.docx         # Word report
```
## 📊 Analysis Flowchart
The complete analysis flow, as a Mermaid diagram:
```mermaid
graph TD
    A[User states a need in natural language] --> B[Initialize agent]
    B --> C[Create a dedicated session directory]
    C --> D[LLM interprets the need and generates code]
    D --> E[Sandboxed executor runs the code]
    E --> F{Execution succeeded?}
    F -->|No| G[Analyze and fix the error]
    G --> D
    F -->|Yes| H[Format and store the result]
    H --> I{More analysis needed?}
    I -->|Yes| J[Continue from current results]
    J --> D
    I -->|No| K[Collect all figures]
    K --> L[Generate the final analysis report]
    L --> M[Emit Markdown and Word reports]
    M --> N[Analysis complete]
    style A fill:#e1f5fe
    style N fill:#c8e6c9
    style F fill:#fff3e0
    style I fill:#fff3e0
```
## 🔄 Agent Workflow
```mermaid
sequenceDiagram
    participant User
    participant Agent as Data Analysis Agent
    participant LLM as Language Model
    participant Executor as Code Executor
    participant Storage as File Storage
    User->>Agent: Provide data files and analysis request
    Agent->>Storage: Create dedicated session directory
    loop Multi-round analysis
        Agent->>LLM: Send request and context
        LLM->>Agent: Return analysis code and reasoning
        Agent->>Executor: Execute Python code
        Executor->>Storage: Save figure files
        Executor->>Agent: Return execution result
        alt More analysis needed
            Agent->>LLM: Continue based on results
        else Analysis complete
            Agent->>LLM: Generate final report
            LLM->>Agent: Return analysis report
            Agent->>Storage: Save report files
        end
    end
    Agent->>User: Return full analysis results
```
## ✨ Core Features
### 🧠 Intelligent Analysis Flow
- **Multi-stage analysis**: data exploration → cleaning checks → analysis and visualization → figure collection → report generation
- **Self-healing on errors**: automatically detects and fixes common issues (encodings, column names, data types, etc.)
- **Context persistence**: variables and state persist throughout the analysis in the Notebook environment
### 📋 Multi-format Reports
- **Markdown report**: a structured analysis report with chart references
- **Word document**: a professional document format, easy to share and print
- **Figure integration**: generated charts are referenced automatically in the report
## 🚀 Quick Start
### 1. Environment Setup
```bash
# Clone the project
git clone https://github.com/li-xiu-qi/data_analysis_agent.git
cd data_analysis_agent
# Install dependencies
pip install -r requirements.txt
```
### 2. Configure the API Key
Create a `.env` file:
```bash
# OpenAI API configuration
OPENAI_API_KEY=your_api_key_here
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_MODEL=gpt-4
# Or use a compatible API, e.g. Volcengine:
# OPENAI_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
# OPENAI_MODEL=deepseek-v3-250324
```
### 3. Basic Usage
```python
from data_analysis_agent import DataAnalysisAgent
from config.llm_config import LLMConfig

# Initialize the agent
llm_config = LLMConfig()
agent = DataAnalysisAgent(llm_config)

# Start the analysis
files = ["your_data.csv"]
report = agent.analyze(
    user_input="Analyze the sales data; produce trend charts and key metrics",
    files=files
)
print(report)
```
```python
# Custom configuration
agent = DataAnalysisAgent(
    llm_config=llm_config,
    output_dir="custom_outputs",  # custom output directory
    max_rounds=30                 # raise the maximum number of analysis rounds
)

# Convenience function
from data_analysis_agent import quick_analysis
report = quick_analysis(
    query="Analyze the user behavior data, focusing on conversion rate",
    files=["user_behavior.csv"],
    max_rounds=15
)
```
## 📊 Usage Example
A complete example analyzing Kweichow Moutai's financial data:
```python
# Example: Moutai financial analysis
files = ["贵州茅台利润表.csv"]
report = agent.analyze(
    user_input="Based on the Kweichow Moutai data, report five key statistical indicators, plot the relevant charts, and produce a summary report.",
    files=files
)
```
**The generated analysis includes:**
- 📈 Total operating revenue trend chart
- 💰 Net profit margin analysis
- 📊 Profit composition charts
- 💵 Earnings-per-share trend
- 📋 Operating cost ratio analysis
- 📄 A comprehensive analysis report
## 🎨 Process Visualization
### 📊 Analysis State Diagram
```mermaid
stateDiagram-v2
    [*] --> Load_data
    Load_data --> Explore_data: loaded
    Load_data --> Fix_encoding: encoding error
    Fix_encoding --> Explore_data: fixed
    Explore_data --> Clean_data: exploration done
    Clean_data --> Statistical_analysis: cleaning done
    Statistical_analysis --> Generate_charts: analysis done
    Generate_charts --> Save_charts: charts generated
    Save_charts --> Evaluate_results: saved
    Evaluate_results --> Continue_analysis: more analysis needed
    Evaluate_results --> Generate_report: analysis sufficient
    Continue_analysis --> Statistical_analysis
    Generate_report --> [*]: done
```
## 🔧 Configuration Options
### LLM Configuration
```python
@dataclass
class LLMConfig:
    provider: str = "openai"
    api_key: str = os.environ.get("OPENAI_API_KEY", "")
    base_url: str = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
    model: str = os.environ.get("OPENAI_MODEL", "gpt-4")
    max_tokens: int = 4000
    temperature: float = 0.1
```
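In practice the config is built once and validated before the first model call. A minimal, self-contained sketch of that pattern (the dataclass is re-declared here so the snippet runs standalone; the real class lives in `config/llm_config.py`):

```python
import os
from dataclasses import dataclass, asdict

@dataclass
class LLMConfig:
    provider: str = "openai"
    api_key: str = os.environ.get("OPENAI_API_KEY", "")
    base_url: str = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
    model: str = os.environ.get("OPENAI_MODEL", "gpt-4")
    max_tokens: int = 4000
    temperature: float = 0.1

    def validate(self) -> bool:
        # Fail fast on a missing key instead of erroring mid-analysis
        if not self.api_key:
            raise ValueError("OPENAI_API_KEY is required")
        return True

# Explicit arguments override whatever the environment provides
config = LLMConfig(api_key="sk-demo", model="gpt-4")
assert config.validate()
print(asdict(config)["model"])  # → gpt-4
```

Environment variables act as defaults only; anything passed explicitly wins.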
### Executor Configuration
```python
# Whitelisted libraries
ALLOWED_IMPORTS = {
    'pandas', 'numpy', 'matplotlib', 'duckdb',
    'scipy', 'sklearn', 'plotly', 'requests',
    'os', 'json', 'datetime', 're', 'pathlib'
}
```
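One way such a whitelist can be enforced is to walk the AST of the generated code before executing it. This is an illustrative sketch, not the repo's actual `CodeExecutor` logic:

```python
import ast

ALLOWED_IMPORTS = {"pandas", "numpy", "matplotlib", "os", "json", "datetime", "re", "pathlib"}

def check_imports(code: str) -> list:
    """Return the top-level modules imported by `code` that are not whitelisted."""
    violations = []
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            names = [node.module or ""]
        else:
            continue
        for name in names:
            root = name.split(".")[0]  # 'matplotlib.pyplot' -> 'matplotlib'
            if root and root not in ALLOWED_IMPORTS:
                violations.append(root)
    return violations

print(check_imports("import pandas as pd\nimport socket"))  # → ['socket']
```

Code that imports anything outside the whitelist can then be rejected before it ever runs.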
## 🎯 Best Practices
### 1. Data Preparation
- ✅ Use CSV files; UTF-8 and GBK encodings are supported
- ✅ Keep column names clear and free of special characters
- ✅ Keep datasets moderate in size (recommended < 100 MB)
### 2. Writing Queries
- ✅ Describe the analysis need clearly
- ✅ Specify the chart types and key metrics you want
- ✅ State the goal and focus of the analysis explicitly
### 3. Interpreting Results
- ✅ Check that the generated charts match expectations
- ✅ Read the key findings in the analysis report
- ✅ Refine the query and re-run as needed
## 🚨 Caveats
### Security Restrictions
- 🔒 Only the predefined data analysis libraries are available
- 🔒 No filesystem operations (other than saving figures)
- 🔒 No network requests (other than LLM calls)
### Performance Considerations
- ⚡ Large datasets can make the analysis slow
- ⚡ Complex tasks may need multiple rounds of interaction
- ⚡ API call frequency is limited by the model provider
### Compatibility
- 🐍 Python 3.8+
- 📊 Any pandas-compatible data format
- 🖼️ matplotlib Chinese font support is required
## 🐛 Troubleshooting
### FAQ
**Q: Chinese characters in charts show as boxes?**
A: The system auto-detects an available Chinese font (macOS: Hiragino Sans GB, Songti SC, etc.; Windows: SimHei, etc.).
**Q: API calls fail?**
A: Check the API key and endpoint in the `.env` file and verify network connectivity.
**Q: Data fails to load?**
A: Check the file path and encoding; common encodings such as UTF-8 and GBK are supported.
**Q: Analysis results look off?**
A: Provide a more detailed analysis request, or check the quality of the source data.
**Q: Mermaid diagrams don't render?**
A: View them in an environment that supports Mermaid (GitHub, Typora, the VS Code preview, etc.). For local viewing, use a Markdown editor with Mermaid support.
**Q: How do I customize the diagram styles?**
A: Add style definitions inside the Mermaid code block, or switch to a different diagram type (graph, flowchart, sequenceDiagram, etc.) as needed.
### Error Logs
Errors raised during analysis are saved in the session directory for debugging and tuning.
## 🤝 Contributing
Code contributions and suggestions are welcome!
1. Fork the project
2. Create a feature branch
3. Commit your changes
4. Push the branch
5. Open a Pull Request
## 📄 License
Released under the MIT License. See the [LICENSE](LICENSE) file for details.
## 🔄 Changelog
### v1.0.0
- ✨ Initial release
- 🎯 Natural-language data analysis
- 📊 matplotlib chart generation
- 📝 Automatic report generation
- 🔒 Sandboxed code execution environment
---
<div align="center">
**🚀 Make data analysis smarter and simpler!**
</div>

5901
UB IOV Support_TR.csv Executable file

File diff suppressed because it is too large

54
__init__.py Normal file

@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
"""
Data Analysis Agent Package
An LLM-powered data analysis agent designed for the Jupyter Notebook environment.
"""
from .core.notebook_agent import NotebookAgent
from .config.llm_config import LLMConfig
from .utils.code_executor import CodeExecutor
__version__ = "1.0.0"
__author__ = "Data Analysis Agent Team"
# Primary exported classes
__all__ = [
"NotebookAgent",
"LLMConfig",
"CodeExecutor",
]
# Convenience functions
def create_agent(config=None, output_dir="outputs", max_rounds=20, session_dir=None):
"""
Create a data analysis agent instance.
Args:
config: LLM configuration; a default config is used when None
output_dir: output directory
max_rounds: maximum number of analysis rounds
session_dir: session directory to reuse (optional)
Returns:
NotebookAgent: the agent instance
"""
if config is None:
config = LLMConfig()
return NotebookAgent(config=config, output_dir=output_dir, max_rounds=max_rounds, session_dir=session_dir)
def quick_analysis(query, files=None, output_dir="outputs", max_rounds=10):
"""
One-shot data analysis helper.
Args:
query: the analysis request (natural language)
files: list of data file paths
output_dir: output directory
max_rounds: maximum number of analysis rounds
Returns:
dict: the analysis result
"""
agent = create_agent(output_dir=output_dir, max_rounds=max_rounds)
return agent.analyze(query, files)

8
config/__init__.py Normal file

@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
"""
Configuration module
"""
from .llm_config import LLMConfig
__all__ = ['LLMConfig']

44
config/llm_config.py Normal file

@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
"""
Configuration management module
"""
import os
from typing import Dict, Any
from dataclasses import dataclass, asdict
from dotenv import load_dotenv
load_dotenv()
@dataclass
class LLMConfig:
"""LLM configuration."""
provider: str = "openai" # openai, anthropic, etc.
api_key: str = os.environ.get("OPENAI_API_KEY", "")
base_url: str = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
model: str = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
temperature: float = 0.3
max_tokens: int = 131072
def to_dict(self) -> Dict[str, Any]:
"""Convert to a plain dict."""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "LLMConfig":
"""Create a config from a dict."""
return cls(**data)
def validate(self) -> bool:
"""Validate the configuration."""
if not self.api_key:
raise ValueError("OPENAI_API_KEY is required")
if not self.base_url:
raise ValueError("OPENAI_BASE_URL is required")
if not self.model:
raise ValueError("OPENAI_MODEL is required")
return True

628
data_analysis_agent.py Normal file

@@ -0,0 +1,628 @@
# -*- coding: utf-8 -*-
"""
Simplified Notebook data analysis agent.
Only two roles are involved: the user and the assistant. Conventions:
1. Figures must be saved into the designated session directory with their absolute paths printed; plt.show() is forbidden.
2. Table output is limited: beyond 15 rows, only the first 5 and last 5 rows are shown.
3. SimHei is the enforced chart font: plt.rcParams['font.sans-serif'] = ['SimHei'].
4. Output strictly uses YAML, in a single-turn dialogue mode with shared context.
"""
import os
import json
import yaml
from typing import Dict, Any, List, Optional
from utils.create_session_dir import create_session_output_dir
from utils.format_execution_result import format_execution_result
from utils.extract_code import extract_code_from_response
from utils.data_loader import load_and_profile_data
from utils.llm_helper import LLMHelper, LLMCallError
from utils.execution_session_client import (
ExecutionSessionClient,
WorkerSessionError,
WorkerTimeoutError,
)
from config.llm_config import LLMConfig
from prompts import data_analysis_system_prompt, final_report_system_prompt
class DataAnalysisAgent:
"""
Data analysis agent.
Responsibilities:
- Receive the user's natural-language request
- Generate Python analysis code
- Execute the code and collect results
- Generate follow-up analysis code based on execution results
"""
def __init__(
self,
llm_config: LLMConfig = None,
output_dir: str = "outputs",
max_rounds: int = 20,
force_max_rounds: bool = False,
):
"""
Initialize the agent.
Args:
llm_config: LLM configuration
output_dir: output directory
max_rounds: maximum number of conversation rounds
force_max_rounds: run to max_rounds even if the AI signals completion
"""
self.config = llm_config or LLMConfig()
self.llm = LLMHelper(self.config)
self.base_output_dir = output_dir
self.max_rounds = max_rounds
self.force_max_rounds = force_max_rounds
# Conversation history and context
self.conversation_history = []
self.analysis_results = []
self.current_round = 0
self.session_output_dir = None
self.executor = None
self.data_profile = "" # cached data profile
self.fatal_error = ""
self.fatal_error_stage = ""
self.session_files = []
self.template_content = "No specific template provided; use your own judgment based on the data profile."
def _process_response(self, response: str) -> Dict[str, Any]:
"""
Process the LLM response: determine the action type and dispatch accordingly.
Args:
response: the LLM response content
Returns:
a result dict
"""
try:
yaml_data = self.llm.parse_yaml_response(response)
action = yaml_data.get("action", "generate_code")
print(f"🎯 Detected action: {action}")
if action == "analysis_complete":
return self._handle_analysis_complete(response, yaml_data)
elif action == "collect_figures":
return self._handle_collect_figures(response, yaml_data)
elif action == "generate_code":
return self._handle_generate_code(response, yaml_data)
else:
print(f"⚠️ Unknown action type: {action}; treating it as generate_code")
return self._handle_generate_code(response, yaml_data)
except Exception as e:
print(f"⚠️ Failed to parse the response: {str(e)}; treating it as generate_code")
return self._handle_generate_code(response, {})
def _handle_analysis_complete(
self, response: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Handle the analysis_complete action."""
print("✅ Analysis task complete")
final_report = yaml_data.get("final_report", "Analysis complete; no final report provided")
return {
"action": "analysis_complete",
"final_report": final_report,
"response": response,
"continue": False,
}
def _handle_collect_figures(
self, response: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Handle the figure-collection action."""
print("📊 Collecting figures")
figures_to_collect = yaml_data.get("figures_to_collect", [])
collected_figures = []
missing_figures = []
for figure_info in figures_to_collect:
figure_number = figure_info.get("figure_number", "unknown")
# Only use figure_number in the filename when it is actually known
if figure_number != "unknown":
default_filename = f"figure_{figure_number}.png"
else:
default_filename = "figure_unknown.png"
filename = figure_info.get("filename", default_filename)
file_path = figure_info.get("file_path", "") # the concrete file path
description = figure_info.get("description", "")
analysis = figure_info.get("analysis", "")
print(f"📈 Collecting figure {figure_number}: (unknown)")
print(f" 📂 Path: {file_path}")
print(f" 📝 Description: {description}")
print(f" 🔍 Analysis: {analysis}")
# Verify the file exists; only existing files are recorded,
# so the final report contains no broken images
if file_path and os.path.exists(file_path):
print(f" ✅ File exists: {file_path}")
# Record the figure info
collected_figures.append(
{
"figure_number": figure_number,
"filename": filename,
"file_path": file_path,
"description": description,
"analysis": analysis,
}
)
else:
# Track missing figures so the caller can report them back to the LLM
missing_figures.append(filename)
if file_path:
print(f" ⚠️ File not found: {file_path}")
else:
print(f" ⚠️ No file path provided")
return {
"action": "collect_figures",
"collected_figures": collected_figures,
"missing_figures": missing_figures,
"response": response,
"continue": True,
}
def _handle_generate_code(
self, response: str, yaml_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Handle the code-generation-and-execution action."""
# Prefer the code field from the YAML payload (more reliable)
code = yaml_data.get("code", "")
# If the YAML carried no code, try extracting it from the raw response
if not code:
code = extract_code_from_response(response)
# Second-pass cleanup: code parsed from YAML may still carry markdown fences
if code:
code = code.strip()
if code.startswith("```"):
import re
# Strip a leading ```python or ```
code = re.sub(r"^```[a-zA-Z]*\n", "", code)
# Strip a trailing ```
code = re.sub(r"\n```$", "", code)
code = code.strip()
if code:
print(f"🔧 Executing code:\n{code}")
print("-" * 40)
# Execute the code
result = self.executor.execute_code(code)
# Format the execution result
feedback = format_execution_result(result)
print(f"📋 Execution feedback:\n{feedback}")
return {
"action": "generate_code",
"code": code,
"result": result,
"feedback": feedback,
"response": response,
"continue": True,
}
else:
# No code means the LLM response was malformed; ask for a regeneration
print("⚠️ No executable code found in the response; asking the LLM to regenerate")
return {
"action": "invalid_response",
"error": "The response contained no executable code",
"response": response,
"continue": True,
}
def analyze(
self,
user_input: str,
files: List[str] = None,
template_path: str = None,
session_output_dir: str = None,
reset_context: bool = True,
keep_session_open: bool = False,
) -> Dict[str, Any]:
"""
Run the analysis flow.
Args:
user_input: the user's natural-language request
files: list of data file paths
template_path: reference template path (optional)
session_output_dir: session output directory to use (optional)
reset_context: start a fresh session context when True
keep_session_open: keep the execution session alive after this call
Returns:
the analysis result dict
"""
files = files or []
self.current_round = 0
self.fatal_error = ""
self.fatal_error_stage = ""
if reset_context or not self.executor:
self.conversation_history = []
self.analysis_results = []
self._initialize_session(user_input, files, template_path, session_output_dir)
initial_prompt = self._build_initial_user_prompt(user_input, files)
else:
self._validate_followup_files(files)
if template_path:
self.template_content = self._load_template_content(template_path)
initial_prompt = self._build_followup_user_prompt(user_input)
print(f"🚀 Starting data analysis task")
print(f"📝 User request: {user_input}")
if files:
print(f"📁 Data files: {', '.join(files)}")
if template_path:
print(f"📄 Reference template: {template_path}")
print(f"📂 Output directory: {self.session_output_dir}")
print(f"🔢 Max rounds: {self.max_rounds}")
# Add the initial prompt to the conversation history
self.conversation_history.append({"role": "user", "content": initial_prompt})
try:
while self.current_round < self.max_rounds:
self.current_round += 1
print(f"\n🔄 Analysis round {self.current_round}")
try:
# Get the variable info from the current execution environment
notebook_variables = self.executor.get_environment_info()
# Format the system prompt with the dynamic variables
# Note: the template in prompts.py now takes three variables
formatted_system_prompt = data_analysis_system_prompt.format(
notebook_variables=notebook_variables,
user_input=user_input,
template_content=self.template_content
)
response = self.llm.call(
prompt=self._build_conversation_prompt(),
system_prompt=formatted_system_prompt,
)
print(f"🤖 Assistant response:\n{response}")
# Process the response through the unified handler
process_result = self._process_response(response)
# Decide whether to continue (only when not forcing max rounds)
if not self.force_max_rounds and not process_result.get(
"continue", True
):
print(f"\n✅ Analysis complete!")
break
# Add to conversation history
self.conversation_history.append(
{"role": "assistant", "content": response}
)
# Feed back differently depending on the action type
if process_result["action"] == "generate_code":
feedback = process_result.get("feedback", "")
self.conversation_history.append(
{"role": "user", "content": f"Code execution feedback:\n{feedback}"}
)
# Record the analysis result
self.analysis_results.append(
{
"round": self.current_round,
"code": process_result.get("code", ""),
"result": process_result.get("result", {}),
"response": response,
}
)
elif process_result["action"] == "collect_figures":
# Record the figure-collection result
collected_figures = process_result.get("collected_figures", [])
missing_figures = process_result.get("missing_figures", [])
feedback = f"Collected {len(collected_figures)} valid figures with their analyses."
if missing_figures:
feedback += f"\n⚠️ The following figures were not found; check that the code actually saved them: {missing_figures}"
self.conversation_history.append(
{
"role": "user",
"content": f"Figure collection feedback:\n{feedback}\nPlease continue with the next analysis step.",
}
)
# Record in the analysis results
self.analysis_results.append(
{
"round": self.current_round,
"action": "collect_figures",
"collected_figures": collected_figures,
"missing_figures": missing_figures,
"response": response,
}
)
except LLMCallError as e:
error_msg = str(e)
self.fatal_error = error_msg
self.fatal_error_stage = "LLM call"
print(f"{error_msg}")
break
except WorkerTimeoutError as e:
error_msg = str(e)
self.fatal_error = error_msg
self.fatal_error_stage = "code execution timeout"
print(f"{error_msg}")
break
except WorkerSessionError as e:
error_msg = str(e)
self.fatal_error = error_msg
self.fatal_error_stage = "executor subprocess failure"
print(f"{error_msg}")
break
except Exception as e:
error_msg = f"Analysis pipeline error: {str(e)}"
print(f"{error_msg}")
self.conversation_history.append(
{
"role": "user",
"content": f"An error occurred: {error_msg}. Please regenerate the code.",
}
)
# Generate the final summary
if self.current_round >= self.max_rounds:
print(f"\n⚠️ Reached the maximum number of rounds ({self.max_rounds}); stopping")
return self._generate_final_report()
finally:
if self.executor and not keep_session_open:
self.executor.close()
self.executor = None
def _build_conversation_prompt(self) -> str:
"""Build the conversation prompt from the history."""
prompt_parts = []
for msg in self.conversation_history:
role = msg["role"]
content = msg["content"]
if role == "user":
prompt_parts.append(f"User: {content}")
else:
prompt_parts.append(f"Assistant: {content}")
return "\n\n".join(prompt_parts)
def _generate_final_report(self) -> Dict[str, Any]:
"""Generate the final analysis report."""
if self.fatal_error:
final_report_content = (
"# Analysis Task Failed\n\n"
f"- Failure stage: round {self.current_round}, {self.fatal_error_stage or 'analysis flow'}\n"
f"- Error: {self.fatal_error}\n"
"- Suggestion: based on the failure stage, check the model configuration, the executor subprocess logs, and the availability of the data files."
)
report_file_path = os.path.join(self.session_output_dir, "最终分析报告.md")
try:
with open(report_file_path, "w", encoding="utf-8") as f:
f.write(final_report_content)
print(f"📄 Failure report saved to: {report_file_path}")
except Exception as e:
print(f"❌ Failed to save the failure report: {str(e)}")
return {
"session_output_dir": self.session_output_dir,
"total_rounds": self.current_round,
"analysis_results": self.analysis_results,
"collected_figures": [],
"conversation_history": self.conversation_history,
"final_report": final_report_content,
"report_file_path": report_file_path,
}
# Collect the info of all generated figures
all_figures = []
for result in self.analysis_results:
if result.get("action") == "collect_figures":
all_figures.extend(result.get("collected_figures", []))
print(f"\n📊 Generating the final analysis report...")
print(f"📂 Output directory: {self.session_output_dir}")
print(f"🔢 Total rounds: {self.current_round}")
print(f"📈 Figures collected: {len(all_figures)}")
# Build the prompt for the final report
final_report_prompt = self._build_final_report_prompt(all_figures)
try: # Call the LLM to generate the final report
response = self.llm.call(
prompt=final_report_prompt,
system_prompt="You will receive a request for the final report of a data analysis task. Generate a complete analysis report from the provided results and figure information.",
max_tokens=16384, # a large token budget so the full report fits
)
# Parse the response and extract the final report
try:
# Try YAML first
yaml_data = self.llm.parse_yaml_response(response)
# Case 1: standard YAML with action: analysis_complete
if yaml_data.get("action") == "analysis_complete":
final_report_content = yaml_data.get("final_report", response)
# Case 2: parsed but the field is missing, or parsing failed
else:
# If the content looks like a Markdown report (has headings), use it directly
if "# " in response or "## " in response:
print("⚠️ No standard YAML action detected, but the content looks like a Markdown report; using it as-is")
final_report_content = response
else:
final_report_content = "The LLM returned no usable report content"
except Exception as e:
# Parsing failed entirely; fall back to the raw response
print(f"⚠️ YAML parsing failed ({e}); using the raw response as the report")
final_report_content = response
print("✅ Final report generated")
except Exception as e:
print(f"❌ Error while generating the final report: {str(e)}")
final_report_content = f"Report generation failed: {str(e)}"
# Save the final report to a file
report_file_path = os.path.join(self.session_output_dir, "最终分析报告.md")
try:
with open(report_file_path, "w", encoding="utf-8") as f:
f.write(final_report_content)
print(f"📄 Final report saved to: {report_file_path}")
except Exception as e:
print(f"❌ Failed to save the report file: {str(e)}")
# Return the complete analysis result
return {
"session_output_dir": self.session_output_dir,
"total_rounds": self.current_round,
"analysis_results": self.analysis_results,
"collected_figures": all_figures,
"conversation_history": self.conversation_history,
"final_report": final_report_content,
"report_file_path": report_file_path,
}
def _build_final_report_prompt(self, all_figures: List[Dict[str, Any]]) -> str:
"""Build the prompt used to generate the final report."""
# Summarize the figure info, using relative paths
figures_summary = ""
if all_figures:
figures_summary = "\nGenerated figures and their analyses:\n"
for i, figure in enumerate(all_figures, 1):
filename = figure.get("filename", "unknown filename")
# Relative paths are suitable for referencing inside the report
relative_path = f"./(unknown)"
figures_summary += f"{i}. (unknown)\n"
figures_summary += f" Relative path: {relative_path}\n"
figures_summary += f" Description: {figure.get('description', 'no description')}\n"
figures_summary += f" Analysis: {figure.get('analysis', 'no analysis')}\n\n"
else:
figures_summary = "\nNo figures were generated in this analysis.\n"
# Summarize code execution results (successful code blocks only)
code_results_summary = ""
success_code_count = 0
for result in self.analysis_results:
if result.get("action") != "collect_figures" and result.get("code"):
exec_result = result.get("result", {})
if exec_result.get("success"):
success_code_count += 1
code_results_summary += f"Code block {success_code_count}: executed successfully\n"
if exec_result.get("output"):
code_results_summary += (
f"Output: {exec_result.get('output')[:]}\n\n"
)
# Use the shared prompt template from prompts.py, plus relative-path instructions
prompt = final_report_system_prompt.format(
current_round=self.current_round,
session_output_dir=self.session_output_dir,
data_profile=self.data_profile, # inject the data profile
figures_summary=figures_summary,
code_results_summary=code_results_summary,
)
# Explicitly require relative paths in the prompt
prompt += """
📁 **Figure path usage**:
The report and its figures live in the same directory, so reference figures with relative paths:
- Format: ![figure description](./figure_filename.png)
- Example: ![Total operating revenue trend](./营业总收入趋势.png)
- This keeps the images rendering correctly across environments
"""
return prompt
def reset(self):
"""Reset the agent state."""
self.conversation_history = []
self.analysis_results = []
self.current_round = 0
self.fatal_error = ""
self.fatal_error_stage = ""
self.session_files = []
self.template_content = "No specific template provided; use your own judgment based on the data profile."
self.close_session()
def close_session(self):
"""Close the current analysis session while keeping the object instance."""
if self.executor:
self.executor.close()
self.executor = None
def _initialize_session(
self,
user_input: str,
files: List[str],
template_path: str,
session_output_dir: str,
) -> None:
if session_output_dir:
self.session_output_dir = session_output_dir
else:
self.session_output_dir = create_session_output_dir(
self.base_output_dir, user_input
)
self.session_files = list(files)
self.executor = ExecutionSessionClient(
output_dir=self.session_output_dir,
allowed_files=self.session_files,
)
data_profile = ""
if self.session_files:
print("🔍 Generating the data profile...")
data_profile = load_and_profile_data(self.session_files)
print("✅ Data profile generated")
self.data_profile = data_profile
self.template_content = self._load_template_content(template_path)
def _load_template_content(self, template_path: str = None) -> str:
template_content = "No specific template provided; use your own judgment based on the data profile."
if template_path and os.path.exists(template_path):
try:
with open(template_path, "r", encoding="utf-8") as f:
template_content = f.read()
print(f"📄 Loaded reference template: {os.path.basename(template_path)}")
except Exception as e:
print(f"⚠️ Failed to read the template: {str(e)}")
return template_content
def _build_initial_user_prompt(self, user_input: str, files: List[str]) -> str:
prompt = f"### Task Start\n**User request**: {user_input}\n"
if files:
prompt += f"**Data files**: {', '.join(files)}\n"
if self.data_profile:
prompt += f"\n{self.data_profile}\n\n"
prompt += "Based on the data profile and the user's question above, decide the most suitable starting point for this round of analysis."
return prompt
def _build_followup_user_prompt(self, user_input: str) -> str:
return (
"### Continue with a Follow-up Topic\n"
f"**New topic request**: {user_input}\n"
"Continue based on the data already loaded in this session, the existing charts, prior analysis results, and intermediate variables."
)
def _validate_followup_files(self, files: List[str]) -> None:
if files and [os.path.abspath(path) for path in files] != [
os.path.abspath(path) for path in self.session_files
]:
raise ValueError("Switching to a different set of data files within the same analysis session is not supported; start a new session.")

74
main.py Normal file

@@ -0,0 +1,74 @@
from data_analysis_agent import DataAnalysisAgent
from config.llm_config import LLMConfig
import sys
import os
from datetime import datetime
from utils.create_session_dir import create_session_output_dir
class DualLogger:
"""A logger that writes to both the terminal and a log file."""
def __init__(self, log_dir, filename="log.txt"):
self.terminal = sys.stdout
log_path = os.path.join(log_dir, filename)
self.log = open(log_path, "a", encoding="utf-8")
def write(self, message):
self.terminal.write(message)
# Keep generated code blocks out of the log file
if "🔧 Executing code:" in message:
return
self.log.write(message)
self.log.flush()
def flush(self):
self.terminal.flush()
self.log.flush()
def setup_logging(log_dir):
"""Configure logging."""
# Record the start time
logger = DualLogger(log_dir)
sys.stdout = logger
# Optional: also redirect stderr
# sys.stderr = logger
print(f"\n{'='*20} Run Started at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} {'='*20}\n")
print(f"📄 Log file saved to: {os.path.join(log_dir, 'log.txt')}")
def main():
llm_config = LLMConfig()
files = ["./UB IOV Support_TR.csv"]
# A simplified request that lets the agent plan autonomously
analysis_requirement = "I want to understand the health of this ticket data. Please run a comprehensive analysis and identify the core problems and improvement suggestions."
# Create the session directory up front so the log has somewhere to live
base_output_dir = "outputs"
session_output_dir = create_session_output_dir(base_output_dir, analysis_requirement)
# Set up logging
setup_logging(session_output_dir)
# force_max_rounds=False lets the AI stop once it considers the analysis done
agent = DataAnalysisAgent(llm_config, max_rounds=20, force_max_rounds=False)
# Pass a path here if you have a specific reference template file
template_path = None
report = agent.analyze(
user_input=analysis_requirement,
files=files,
template_path=template_path,
session_output_dir=session_output_dir
)
print("\n" + "="*60)
print(f"✅ Analysis task completed!")
print(f"📊 报告路径: {report.get('report_file_path')}")
print("="*60)
if __name__ == "__main__":
main()

147
prompts.py Normal file

@@ -0,0 +1,147 @@
data_analysis_system_prompt = """You are a top-tier AI data scientist agent running in a Jupyter Notebook environment. Your mission: understand the data like a human expert, choose appropriate methods, autonomously generate analysis code, and deliver evidence-backed business conclusions.
### Core Principles
1. **Goal-driven**: organize the analysis around the user's question; do not mechanically follow a fixed pipeline, and do not use a method just for the sake of using it.
2. **Free choice of method**: each round you may decide to do descriptive statistics, visualization, anomaly detection, correlation analysis, group comparison, regression, clustering, text mining, or root-cause decomposition; but always explain why the chosen method fits the current question.
3. **Evidence first**: every important conclusion must be backed by a metric, table, chart, or statistical test. Without evidence, state explicitly that the evidence is insufficient to conclude.
4. **Privacy protection**:
- Never dump large chunks of raw data.
- Prefer aggregates, summary statistics, distributions, group comparisons, and sample sizes.
- All raw-data processing must happen in the local Python environment.
5. **Dynamic drill-down**: after finding anomalies, structural differences, or suspicious patterns, keep slicing dimensions to chase the cause; stop drilling once it adds no value.
6. **Business interpretation**: every chart and statistic must answer "what does this show" and "what does it mean for the business".
---
### Environment
- Data processing: `pandas`, `numpy`, `duckdb`, `scipy`
- Visualization: `matplotlib`, `seaborn`, `plotly`
- Modeling and statistics: `scikit-learn`
- Text analysis: `sklearn.feature_extraction.text`
- Chinese charts: Chinese font rendering is supported
---
### Method Guidance
- To understand data structure, use: dimensions, field types, missing rates, distributions, duplicate rates, time ranges, grouped summaries.
- To detect anomalies, use: quantiles/IQR, Z-scores, group comparisons, time-series volatility, anomaly-detection models, as appropriate.
- To explain drivers, use: correlation analysis, group-mean comparisons, chi-square tests, regression, tree-model feature importance, as appropriate.
- To find structure or segments, use: clustering, stratified grouping, contribution analysis, Pareto analysis, as appropriate.
- For text fields, use: 2-gram/3-gram, TF-IDF, frequent phrases, similar-text clustering, as appropriate.
- Do not hard-code a method order; pick whatever best fits the current data and question.
### Evidence Standards
- Conclusions must map to verifiable metrics or charts.
- When comparing groups, state the sample sizes; avoid conclusions from tiny samples.
- When modeling or running tests, briefly state the target variable, features, or the object of the test.
- If a chart adds no clear value, output just the table and the conclusion.
- If the user provided a reference template, you may borrow its structure, but do not let it constrain your analytical judgment.
### Statistical Quality Rules
- Justify why a statistical method fits the question instead of applying it mechanically.
- For group comparisons, state at least the sample size, comparison dimension, and core difference metric.
- For correlation analysis, distinguish correlation from causation; never present correlation as causal.
- For regression, tree models, or clustering, state the target variable, inputs, and meaning of the result; never output only a model score.
- For anomaly detection, state the detection criterion, e.g. a quantile threshold, standard-deviation threshold, model score, or time-series deviation.
- For text analysis, output phrases, topics, or category patterns interpreted in their business context; never just isolated word counts.
- If the sample is too small, field quality is poor, or method assumptions fail, weaken the conclusion and state the limitation explicitly.
### Evidence Chain
- Each key finding should form the chain: metric/statistic -> chart or table -> business interpretation -> recommended action.
- When citing a chart, state which question it answers; do not just show an image.
- Without charts, provide enough numeric evidence or table summaries to support the conclusion.
- Recommendations must tie back to the preceding evidence; never give generic advice unrelated to the data.
---
### Hard Constraints
1. **Figure saving**: save with `os.path.join(session_output_dir, '中文文件名.png')` (i.e. a Chinese filename) and print the absolute path. `plt.show()` is forbidden.
2. **Chart rules**:
- Prefer bar charts over hard-to-read pie charts when there are many categories.
- Set `plt.rcParams['font.sans-serif']` so Chinese text renders correctly.
- Titles, axes, and legends must directly support business interpretation.
3. **Text analysis**: when analyzing text fields, extract phrases or topics; do not just count single characters.
4. **Response format**: always output valid YAML.
5. **Action selection**:
- Use `generate_code` to continue the analysis
- Use `collect_figures` once key charts exist and should enter the final evidence chain
- Use `analysis_complete` once the analysis suffices to answer the user's question
---
### Response Format Examples
**Exploration/analysis round:**
```yaml
action: "generate_code"
reasoning: "First verify the data structure, key field distributions, and sample sizes before deciding whether anomaly detection or group comparison is needed."
code: |
import pandas as pd
import os
# Load the data
# Print aggregate statistics
# Decide the next step based on the results
next_steps:
- "Confirm the data quality of the key fields"
- "Check for anomalous distributions or structural differences"
```
**Final report round:**
```yaml
action: "analysis_complete"
final_report: |
# Professional Markdown report content...
```
Current Jupyter Notebook environment variables:
{notebook_variables}
User request: {user_input}
Reference template (if any): {template_content}
"""
# Final-report generation prompt
final_report_system_prompt = """You are a Senior Data Analyst. Based on the detailed analysis process below, write a professional, actionable business analysis report.
### Input Context
- **Data Profile**:
{data_profile}
- **Analysis process and key code findings**:
{code_results_summary}
- **Visual Evidence**:
{figures_summary}
> **Warning**: you must reference the key charts that were generated, using the format `![description](./figure_filename.png)`.
### Core Report Requirements
1. **Objectivity**: no first-person pronouns such as "I" or "we"; use declarative language.
2. **Closure**: for every core direction in the initial plan, give either a clear conclusion or the reason why data limitations prevent one.
3. **Actionability**: the final "recommendation matrix" must be specific about owner, timeline, and expected benefit.
---
### Report Structure Template (Markdown)
# [Project/Product Name] Business Insight Report
## 1. Executive Summary
- **Core insight**: [one sentence]
- **Health score**: [0-100]
- **Top 3 key findings**
- **Preview of core recommendations**
## 2. Methodology
- **Data coverage**: [time window, sample size]
- **Analysis framework**: [e.g. user funnel, RCA attribution, time-series forecasting]
- **Limitations**: [e.g. analyses blocked by a missing field]
## 3. Key Insights
[Organized by business theme; each theme must include evidence charts, observed data, and a business conclusion]
## 4. Deep Dive
[Drill-down conclusions for the anomalies found during the analysis]
## 5. Recommendations
[Item | Priority | Key actions | Expected benefit | Owner]
"""

447
require.md Normal file

@@ -0,0 +1,447 @@
# A True AI Data Analysis Agent - Requirements
## 1. Background
### 1.1 Current Problems
The existing system is neither one thing nor the other:
- Task planning: template-driven rules generating a fixed set of 90 tasks
- Task execution: AI-driven ReAct mode
- Result: rules + AI = incoherent and inflexible
### 1.2 Core Problem
**What users actually ask for**:
> "I have some data, analyze it for me"
> "I want to understand the health of my tickets"
> "Analyze it along this template, but adapt flexibly"
**What the system should do**:
- Understand the data like a human analyst
- Decide autonomously what to analyze
- Adjust the analysis plan based on findings
- Produce insightful reports
**Not**:
- Mechanically execute fixed tasks
- Rigidly fill in a template
## 2. User Stories
### 2.1 Scenario 1: Fully Autonomous Analysis
**As a** data analyst
**I want** to upload a data file and have the AI analyze it automatically
**so that** I can quickly understand its key characteristics.
**Acceptance criteria**:
- The AI recognizes the data type (tickets, sales, users, ...)
- The AI infers the business meaning of key fields
- The AI chooses analysis dimensions autonomously
- The AI produces a reasonable analysis plan
- The AI executes the analysis and produces a report
- The report contains key findings and insights
**Example**:
```
Input: cleaned_data.csv
Output:
- Data type: ticket data
- Key findings:
  * pending tickets: 50% (abnormally high)
  * one vehicle model accounts for 80% of issues
  * average handling time is 2x the standard
- Recommendation: prioritize the backlog for that model
```
### 2.2 Scenario 2: Directed Analysis
**As a** business owner
**I want** to specify an analysis direction (e.g. "health")
**so that** I get targeted analysis results.
**Acceptance criteria**:
- The AI understands the business meaning of "health"
- The AI turns the abstract concept into concrete metrics
- The AI picks methods suited to the data's characteristics
- The AI produces a targeted report
**Example**:
```
Input:
- Data: cleaned_data.csv
- Request: "I want to understand the health of my tickets"
AI interpretation:
- health = closure rate + handling efficiency + backlog + response timeliness
AI analysis:
- closure rate: 75% (moderate)
- average handling time: 48 hours (long)
- backlog: 50% (severe)
- health score: 60/100 (needs improvement)
```
### 2.3 Scenario 3: Template-guided Analysis
**As a** data analyst
**I want** to use a template as a reference frame
**so that** report structure stays consistent without losing flexibility.
**Acceptance criteria**:
- The AI understands the template's structure and requirements
- The AI checks whether the data satisfies the template
- When fields are missing, the AI adapts flexibly
- The AI organizes the report along the template's structure
- The AI does not fail just because the data does not match exactly
**Example**:
```
Input:
- Data: cleaned_data.csv
- Template: issue_analysis.md (requires 14 charts)
AI check:
- The template asks for a "severity distribution", but the data has no severity field
- Decision: skip that analysis and note it in the report
AI adjustment:
- Run the other 13 analyses
- Note in the report: "data lacks a severity field, so that dimension cannot be analyzed"
```
### 2.4 Scenario 4: Iterative Deep Dive
**As a** data analyst
**I want** the AI to drill into its findings
**so that** the root cause of a problem is found.
**Acceptance criteria**:
- The AI spots anomalies or key findings
- The AI decides autonomously whether to dig deeper
- The AI adjusts the analysis plan dynamically
- The AI traces problems to their root cause
**Example**:
```
Initial analysis:
- pending tickets: 50% (abnormally high)
AI decision: dig deeper
Deep dive 1:
- profile the pending tickets
- finding: one vehicle model accounts for 80%
AI decision: continue
Deep dive 2:
- analyze that model's issue types
- finding: all are "remote control" issues
AI decision: continue
Deep dive 3:
- analyze the module distribution of the "remote control" issues
- finding: 90% are the "door module"
Conclusion: the door module's remote-control function has a systemic problem
```
## 3. 功能需求
### 3.1 数据理解Data Understanding
**FR-1.1 数据加载**
- 系统应支持 CSV 格式数据
- 系统应自动检测编码UTF-8, GBK等
- 系统应处理常见的数据格式问题
**FR-1.2 数据类型识别**
- AI 应分析列名、数据类型、值分布
- AI 应推断数据的业务类型(工单、销售、用户等)
- AI 应识别关键字段(时间、状态、分类、数值)
**FR-1.3 字段含义理解**
- AI 应推断每个字段的业务含义
- AI 应识别字段之间的关系
- AI 应识别可能的分析维度
**FR-1.4 数据质量评估**
- AI 应检查缺失值
- AI 应检查异常值
- AI 应评估数据质量分数
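FR-1.4 的质量分数可以先用一个基于缺失率与重复率的基础规则勾勒(扣分权重为示意性假设,实际应由 AI 结合字段重要性调整):

```python
import pandas as pd

def data_quality_score(df: pd.DataFrame) -> float:
    """基础数据质量分:满分100按整体缺失率与重复行比例扣分示例规则。"""
    missing_ratio = float(df.isnull().mean().mean())   # 全表平均缺失率
    duplicate_ratio = float(df.duplicated().mean())    # 完全重复行的比例
    score = 100 - missing_ratio * 60 - duplicate_ratio * 40
    return round(max(score, 0.0), 1)

df = pd.DataFrame({"a": [1, None, 3, 3], "b": ["x", "y", "z", "z"]})
print(data_quality_score(df))  # 82.5
```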
### 3.2 需求理解Requirement Understanding
**FR-2.1 自主需求推断**
- 当用户未指定需求时AI 应根据数据类型推断常见分析需求
- AI 应生成默认的分析目标
**FR-2.2 用户需求理解**
- AI 应理解用户的自然语言需求
- AI 应将抽象概念转化为具体指标
- AI 应判断数据是否支持用户需求
**FR-2.3 模板理解**
- AI 应解析模板结构
- AI 应理解模板要求的指标和图表
- AI 应检查数据是否满足模板要求
- AI 应在数据不满足时灵活调整
### 3.3 分析规划Analysis Planning
**FR-3.1 动态任务生成**
- AI 应根据数据特征和需求生成分析任务
- 任务应是动态的,不是固定的
- 任务应包含优先级和依赖关系
**FR-3.2 任务优先级**
- AI 应根据重要性排序任务
- 必需的分析应优先执行
- 可选的分析应后执行
**FR-3.3 计划调整**
- AI 应能根据中间结果调整计划
- AI 应能增加新的深入分析任务
- AI 应能跳过不适用的任务
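FR-3.1FR-3.3 描述的“可排序、可动态插入、可跳过”的任务计划,可以用一个最小的优先级队列勾勒(任务名与优先级数值均为示意):

```python
import heapq

class AnalysisPlan:
    """动态分析计划:按优先级出队,可随时插入深入分析任务。"""
    def __init__(self):
        self._heap = []
        self._counter = 0  # 保证同优先级任务按插入顺序执行

    def add(self, name: str, priority: int = 5) -> None:
        heapq.heappush(self._heap, (priority, self._counter, name))
        self._counter += 1

    def next_task(self):
        return heapq.heappop(self._heap)[2] if self._heap else None

plan = AnalysisPlan()
plan.add("整体分布分析", priority=1)
plan.add("可选:文本聚类", priority=9)
plan.add("深入:待处理工单特征", priority=2)  # 根据中间结果动态插入
print(plan.next_task())  # 整体分布分析
```

“跳过不适用的任务”在这一形态下即为不入队(或出队后判定放弃),由 AI 在规划阶段决定。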
### 3.4 工具集管理Tool Management
**FR-4.1 预设工具集**
- 系统应提供基础数据分析工具集
- 基础工具包括:数据查询、统计分析、可视化、数据清洗
- 工具应有标准的接口和描述
**FR-4.2 动态工具调整**
- AI 应根据数据特征决定需要哪些工具
- AI 应根据分析需求动态启用/禁用工具
- AI 应能识别缺少的工具并请求添加
**FR-4.3 工具适配**
- AI 应根据数据类型调整工具参数
- 例如:时间序列数据 → 启用趋势分析工具
- 例如:分类数据 → 启用分布分析工具
- 例如:地理数据 → 启用地图可视化工具
**FR-4.4 自定义工具生成**
- AI 应能根据特定需求生成临时工具
- AI 应能组合现有工具创建新功能
- 自定义工具应在分析结束后可选保留
**示例**
```
数据特征:
- 包含时间字段created_at, closed_at
- 包含分类字段status, type, model
- 包含数值字段duration
AI 决策:
- 启用工具:时间序列分析、分类分布、数值统计
- 禁用工具:地理分析(无地理字段)
- 生成工具计算处理时长closed_at - created_at
```
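上例中“根据字段类型启用/禁用工具”的决策,可以草拟成一个基于 dtype 的映射函数(工具名为示意,实际工具集见 FR-4.1

```python
import pandas as pd

def select_tools(df: pd.DataFrame) -> set:
    """根据各列的数据类型推断应启用的分析工具集(工具名为示例)。"""
    tools = set()
    for col in df.columns:
        dtype = df[col].dtype
        if pd.api.types.is_datetime64_any_dtype(dtype):
            tools.add("时间序列分析")
        elif pd.api.types.is_numeric_dtype(dtype):
            tools.add("数值统计")
        else:
            tools.add("分类分布")
    return tools

df = pd.DataFrame({
    "created_at": pd.to_datetime(["2026-01-01", "2026-01-02"]),
    "status": ["open", "closed"],
    "duration": [12.5, 48.0],
})
print(select_tools(df))
```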
### 3.5 分析执行Analysis Execution
**FR-5.1 ReAct 执行模式**
- 每个任务应使用 ReAct 模式执行
- AI 应思考 → 行动 → 观察 → 判断
- AI 应能从错误中学习
**FR-5.2 工具调用**
- AI 应从可用工具集中选择合适的工具
- AI 应能组合多个工具完成复杂任务
- AI 应能处理工具调用失败的情况
**FR-5.3 结果验证**
- AI 应验证每个任务的结果
- AI 应识别异常结果
- AI 应决定是否需要重试或调整
**FR-5.4 迭代深入**
- AI 应识别关键发现
- AI 应决定是否需要深入分析
- AI 应动态增加深入分析任务
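FR-5.1 的“思考 → 行动 → 观察 → 判断”循环可以抽象为下面的骨架(`think` 与工具集均为示意性占位,实际由 LLM 驱动决策):

```python
def react_loop(think, tools, max_steps: int = 5):
    """最小 ReAct 骨架think 返回 (工具名, 参数) 或 ("finish", 结果)。"""
    observations = []
    for _ in range(max_steps):
        action, arg = think(observations)  # 思考:基于历史观察决定下一步行动
        if action == "finish":
            return arg                     # 判断:目标达成,结束循环
        result = tools[action](arg)        # 行动:调用工具
        observations.append(result)        # 观察:记录结果供下一轮思考
    return None

# 示意:一个两步后即得出结论的假想策略
def fake_think(obs):
    return ("count", "status") if not obs else ("finish", f"发现: {obs[-1]}")

tools = {"count": lambda col: f"{col} 列中 open 占比 50%"}
print(react_loop(fake_think, tools))
```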
### 3.6 报告生成Report Generation
**FR-6.1 关键发现提炼**
- AI 应从所有结果中提炼关键发现
- AI 应识别异常和趋势
- AI 应提供洞察而不是简单罗列数据
**FR-6.2 报告结构组织**
- AI 应根据分析内容组织报告结构
- 如果有模板,应参考模板结构
- 如果没有模板,应生成合理的结构
**FR-6.3 结论和建议**
- AI 应基于分析结果得出结论
- AI 应提供可操作的建议
- AI 应说明建议的依据
**FR-6.4 多格式输出**
- 系统应生成 Markdown 格式报告
- 系统应支持导出为 Word 文档(可选)
- 报告应包含所有生成的图表
## 4. 非功能需求
### 4.1 性能需求
**NFR-1.1 响应时间**
- 数据理解阶段:< 30秒
- 分析规划阶段:< 60秒
- 单个任务执行:< 120秒
- 完整分析流程:< 30分钟(取决于数据大小和任务数量)
**NFR-1.2 数据规模**
- 支持最大 100MB 的 CSV 文件
- 支持最大 100万行数据
- 支持最大 100列
### 4.2 可靠性需求
**NFR-2.1 错误处理**
- AI 调用失败时应有降级策略
- 单个任务失败不应影响整体流程
- 系统应记录详细的错误日志
**NFR-2.2 数据安全**
- 数据应在本地处理,不上传到外部服务
- 生成的报告应保存在用户指定的目录
- 敏感信息应脱敏处理
### 4.3 可用性需求
**NFR-3.1 易用性**
- 用户只需提供数据文件即可开始分析
- 分析过程应显示进度和状态
- 错误信息应清晰易懂
**NFR-3.2 可观察性**
- 系统应显示 AI 的思考过程
- 系统应显示每个阶段的进度
- 系统应记录完整的执行日志
### 4.4 可扩展性需求
**NFR-4.1 工具扩展**
- 应易于添加新的分析工具
- 工具应有标准接口
- AI 应能自动发现和使用新工具
- 工具应支持热加载,无需重启系统
**NFR-4.2 工具动态性**
- 工具集应根据数据特征动态调整
- 工具参数应根据数据类型自适应
- 系统应支持运行时生成临时工具
**NFR-4.3 模型扩展**
- 应支持不同的 LLM 提供商
- 应支持本地模型和云端模型
- 应支持模型切换
## 5. 约束条件
### 5.1 技术约束
- 使用 Python 3.8+
- 使用 OpenAI 兼容的 LLM API
- 使用 pandas 进行数据处理
- 使用 matplotlib 进行可视化
### 5.2 业务约束
- 系统应在离线环境下工作(除 LLM 调用外)
- 系统不应依赖特定的数据格式或业务领域
- 系统应保持通用性,适用于各种数据分析场景
### 5.3 隐私和安全约束
**数据隐私保护**
- AI 不能访问完整的原始数据内容
- AI 只能读取:
- 表头(列名)
- 数据类型信息
- 基本统计摘要(行数、列数、缺失值比例、数据类型分布)
- 工具执行后的聚合结果(如分组统计结果、图表数据)
- 所有原始数据处理必须在本地完成,不发送给 LLM
- AI 通过调用本地工具来分析数据,工具返回摘要结果而非原始数据
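“只向 LLM 暴露表头与统计摘要”的约束,可以用一个脱敏画像函数勾勒(字段均为示意):

```python
import pandas as pd

def privacy_safe_profile(df: pd.DataFrame) -> dict:
    """只返回可发送给 LLM 的元信息,不包含任何单元格原始值。"""
    return {
        "columns": list(df.columns),
        "dtypes": {c: str(df[c].dtype) for c in df.columns},
        "rows": int(len(df)),
        "missing_ratio": {c: float(df[c].isnull().mean()) for c in df.columns},
    }

df = pd.DataFrame({"phone": ["13800000000", None], "status": ["open", "closed"]})
profile = privacy_safe_profile(df)
print(profile["columns"])                 # ['phone', 'status']
assert "13800000000" not in str(profile)  # 原始值不会出现在画像中
```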
## 6. 验收标准
### 6.1 场景1验收
- [ ] 上传任意 CSV 文件AI 能识别数据类型
- [ ] AI 能自主生成分析计划
- [ ] AI 能执行分析并生成报告
- [ ] 报告包含关键发现和洞察
### 6.2 场景2验收
- [ ] 指定"健康度"等抽象需求AI 能理解
- [ ] AI 能生成相关指标
- [ ] AI 能执行针对性分析
- [ ] 报告聚焦于用户需求
### 6.3 场景3验收
- [ ] 提供模板AI 能理解模板要求
- [ ] 数据缺少字段时AI 能灵活调整
- [ ] 报告按模板结构组织
- [ ] 报告说明哪些分析被跳过及原因
### 6.4 场景4验收
- [ ] AI 能识别异常发现
- [ ] AI 能自主决定深入分析
- [ ] AI 能动态调整分析计划
- [ ] 报告包含深入分析的结果
### 6.5 工具动态性验收
- [ ] 系统根据数据特征自动启用相关工具
- [ ] 系统根据数据特征自动禁用无关工具
- [ ] AI 能识别需要但缺失的工具
- [ ] AI 能生成临时工具满足特定需求
- [ ] 工具参数根据数据类型自动调整
## 7. 成功指标
### 7.1 功能指标
- 数据类型识别准确率 > 90%
- 字段含义推断准确率 > 80%
- 分析计划合理性(人工评估)> 85%
- 报告质量(人工评估)> 80%
### 7.2 性能指标
- 完整分析流程完成率 > 95%
- AI 调用成功率 > 90%
### 7.3 用户满意度
- 用户认为分析结果有价值 > 80%
- 用户愿意再次使用 > 85%
- 用户推荐给他人 > 75%
---
**版本**: v3.0.0
**日期**: 2026-03-06
**状态**: 需求定义完成

55
requirements.txt Normal file
View File

@@ -0,0 +1,55 @@
# 数据分析和科学计算库
pandas>=2.0.0
openpyxl>=3.1.0
numpy>=1.24.0
matplotlib>=3.6.0
duckdb>=0.8.0
scipy>=1.10.0
scikit-learn>=1.3.0
# Web和API相关
requests>=2.28.0
urllib3>=1.26.0
fastapi>=0.110.0
uvicorn>=0.29.0
python-multipart>=0.0.9
# 绘图和可视化
plotly>=5.14.0
dash>=2.0.0
# 流程图支持可选用于生成Mermaid图表
# 注意Mermaid图表主要在markdown中渲染不需要额外的Python包
# 如果需要在Python中生成Mermaid代码可以考虑
# mermaid-py>=0.3.0
# Jupyter/IPython环境
ipython>=8.10.0
jupyter>=1.0.0
# AI/LLM相关
openai>=1.0.0
pyyaml>=6.0
# 配置管理
python-dotenv>=1.0.0
# 异步编程
asyncio-mqtt>=0.11.1
nest_asyncio>=1.5.0
# 文档生成基于输出的Word文档
python-docx>=0.8.11
# 系统和工具库
# pathlib2 为 Python 2 的反向移植Python 3.8+ 自带 pathlib无需安装
typing-extensions>=4.5.0
# 开发和测试工具(可选)
pytest>=7.0.0
pytest-asyncio>=0.21.0
black>=23.0.0
flake8>=6.0.0
# 字体支持用于matplotlib中文显示
fonttools>=4.38.0

16
utils/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""
工具模块初始化文件
"""
from utils.code_executor import CodeExecutor
from utils.execution_session_client import ExecutionSessionClient
from utils.llm_helper import LLMHelper
from utils.fallback_openai_client import AsyncFallbackOpenAIClient
__all__ = [
"CodeExecutor",
"ExecutionSessionClient",
"LLMHelper",
"AsyncFallbackOpenAIClient",
]

456
utils/code_executor.py Normal file
View File

@@ -0,0 +1,456 @@
# -*- coding: utf-8 -*-
"""
安全的代码执行器,基于 IPython 提供 notebook 环境下的代码执行功能
"""
import os
import sys
import ast
import traceback
import io
from typing import Dict, Any, List, Optional, Tuple
from contextlib import redirect_stdout, redirect_stderr
from IPython.core.interactiveshell import InteractiveShell
from IPython.utils.capture import capture_output
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
class CodeExecutor:
"""
安全的代码执行器,限制依赖库,捕获输出,支持图片保存与路径输出
"""
ALLOWED_IMPORTS = {
"pandas",
"pd",
"numpy",
"np",
"matplotlib",
"matplotlib.pyplot",
"plt",
"seaborn",
"sns",
"duckdb",
"scipy",
"sklearn",
"statsmodels",
"plotly",
"dash",
"requests",
"urllib",
"os",
"sys",
"json",
"csv",
"datetime",
"time",
"math",
"statistics",
"re",
"pathlib",
"io",
"collections",
"itertools",
"functools",
"operator",
"warnings",
"logging",
"copy",
"pickle",
"gzip",
"zipfile",
"yaml",
"typing",
"dataclasses",
"enum",
"sqlite3",
"jieba",
"wordcloud",
"PIL",
"random",
"networkx",
}
def __init__(self, output_dir: str = "outputs"):
"""
初始化代码执行器
Args:
output_dir: 输出目录,用于保存图片和文件
"""
self.output_dir = os.path.abspath(output_dir)
os.makedirs(self.output_dir, exist_ok=True)
# 为每个执行器创建独立的 shell避免跨分析任务共享状态
self.shell = InteractiveShell()
# 初始化隔离执行环境
self.reset_environment()
# 图片计数器
self.image_counter = 0
def _setup_chinese_font(self):
"""设置matplotlib中文字体显示"""
try:
# 设置matplotlib使用Agg backend避免GUI问题
matplotlib.use("Agg")
# 获取系统可用字体
available_fonts = [f.name for f in fm.fontManager.ttflist]
# 设置matplotlib使用系统可用中文字体
# macOS系统常用中文字体按优先级排序
chinese_fonts = [
"Hiragino Sans GB", # macOS中文简体
"Songti SC", # macOS宋体简体
"PingFang SC", # macOS苹方简体
"Heiti SC", # macOS黑体简体
"Heiti TC", # macOS黑体繁体
"PingFang HK", # macOS苹方香港
"SimHei", # Windows黑体
"STHeiti", # 华文黑体
"WenQuanYi Micro Hei", # Linux文泉驿微米黑
"DejaVu Sans", # 默认无衬线字体
"Arial Unicode MS", # Arial Unicode
]
# 检查系统中实际存在的字体
system_chinese_fonts = [
font for font in chinese_fonts if font in available_fonts
]
# 如果没有找到合适的中文字体,尝试更宽松的搜索
if not system_chinese_fonts:
print("警告:未找到精确匹配的中文字体,尝试更宽松的搜索...")
# 更宽松的字体匹配(包含部分名称)
fallback_fonts = []
for available_font in available_fonts:
if any(
keyword in available_font
for keyword in [
"Hei",
"Song",
"Fang",
"Kai",
"Hiragino",
"PingFang",
"ST",
]
):
fallback_fonts.append(available_font)
if fallback_fonts:
system_chinese_fonts = fallback_fonts[:3] # 取前3个匹配的字体
print(f"找到备选中文字体: {system_chinese_fonts}")
else:
print("警告:系统中未找到合适的中文字体,使用系统默认字体")
system_chinese_fonts = ["DejaVu Sans", "Arial Unicode MS"]
# 设置字体配置
plt.rcParams["font.sans-serif"] = system_chinese_fonts + [
"DejaVu Sans",
"Arial Unicode MS",
]
plt.rcParams["axes.unicode_minus"] = False
plt.rcParams["font.family"] = "sans-serif"
# 在shell中也设置相同的字体配置
font_list_str = str(
system_chinese_fonts + ["DejaVu Sans", "Arial Unicode MS"]
)
self.shell.run_cell(
f"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
# 设置中文字体
plt.rcParams['font.sans-serif'] = {font_list_str}
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.family'] = 'sans-serif'
# 确保matplotlib缓存目录可写
import os
cache_dir = os.path.expanduser('~/.matplotlib')
if not os.path.exists(cache_dir):
os.makedirs(cache_dir, exist_ok=True)
os.environ['MPLCONFIGDIR'] = cache_dir
"""
)
except Exception as e:
print(f"设置中文字体失败: {e}")
# 即使失败也要设置基本的matplotlib配置
try:
matplotlib.use("Agg")
plt.rcParams["axes.unicode_minus"] = False
except Exception:
pass
def _setup_common_imports(self):
"""预导入常用库"""
common_imports = """
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import duckdb
import os
import json
from IPython.display import display
"""
try:
self.shell.run_cell(common_imports)
# 确保display函数在shell的用户命名空间中可用
from IPython.display import display
self.shell.user_ns["display"] = display
except Exception as e:
print(f"预导入库失败: {e}")
def _check_code_safety(self, code: str) -> Tuple[bool, str]:
"""
检查代码安全性,限制导入的库
Returns:
(is_safe, error_message)
"""
try:
tree = ast.parse(code)
except SyntaxError as e:
return False, f"语法错误: {e}"
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
if alias.name not in self.ALLOWED_IMPORTS:
return False, f"不允许的导入: {alias.name}"
elif isinstance(node, ast.ImportFrom):
if node.module not in self.ALLOWED_IMPORTS:
return False, f"不允许的导入: {node.module}"
# 检查属性访问防止通过os.system等方式绕过
elif isinstance(node, ast.Attribute):
# 检查是否访问了os模块的属性
if isinstance(node.value, ast.Name) and node.value.id == "os":
# 允许的os属性白名单注意 node.attr 只会是单级属性名,
# 形如 os.path.join 的访问在这里先命中 "path"
allowed_os_attributes = {
"path", "environ", "getcwd", "listdir", "makedirs", "mkdir", "remove", "rmdir",
"sep", "name", "linesep", "stat", "getpid"
}
# 检查直接属性访问 (如 os.getcwd)
if node.attr not in allowed_os_attributes:
return False, f"不允许的os属性访问: os.{node.attr}"
# 检查危险函数调用
elif isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
if node.func.id in ["exec", "eval", "__import__"]:
return False, f"不允许的函数调用: {node.func.id}"
return True, ""
def get_current_figures_info(self) -> List[Dict[str, Any]]:
"""获取当前matplotlib图形信息但不自动保存"""
figures_info = []
# 获取当前所有图形
fig_nums = plt.get_fignums()
for fig_num in fig_nums:
fig = plt.figure(fig_num)
if fig.get_axes(): # 只处理有内容的图形
figures_info.append(
{
"figure_number": fig_num,
"axes_count": len(fig.get_axes()),
"figure_size": fig.get_size_inches().tolist(),
"has_content": True,
}
)
return figures_info
def _format_table_output(self, obj: Any) -> str:
"""格式化表格输出:附带形状与列名信息,超过15行时仅显示首尾各5行"""
if hasattr(obj, "shape") and hasattr(obj, "head"):  # pandas DataFrame
rows, cols = obj.shape
# 形状与列名放进返回值,而不是 print 到真实 stdout该方法在输出捕获之外调用
header = f"数据表形状: {rows}行 x {cols}列\n列名: {list(obj.columns)}\n"
if rows <= 15:
return header + str(obj)
head_part = obj.head(5)
tail_part = obj.tail(5)
return f"{header}{head_part}\n...\n(省略 {rows-10} 行)\n...\n{tail_part}"
return str(obj)
def execute_code(self, code: str) -> Dict[str, Any]:
"""
执行代码并返回结果
Args:
code: 要执行的Python代码
Returns:
{
'success': bool,
'output': str,
'error': str,
'variables': Dict[str, Any] # 新生成的重要变量
}
"""
# 检查代码安全性
is_safe, safety_error = self._check_code_safety(code)
if not is_safe:
return {
"success": False,
"output": "",
"error": f"代码安全检查失败: {safety_error}",
"variables": {},
}
# 记录执行前的变量
vars_before = set(self.shell.user_ns.keys())
try:
# 使用IPython的capture_output来捕获所有输出
with capture_output() as captured:
result = self.shell.run_cell(code)
# 检查执行结果
if result.error_before_exec:
error_msg = str(result.error_before_exec)
return {
"success": False,
"output": captured.stdout,
"error": f"执行前错误: {error_msg}",
"variables": {},
}
if result.error_in_exec:
error_msg = str(result.error_in_exec)
return {
"success": False,
"output": captured.stdout,
"error": f"执行错误: {error_msg}",
"variables": {},
}
# 获取输出
output = captured.stdout
# 如果有返回值,添加到输出
if result.result is not None:
formatted_result = self._format_table_output(result.result)
output += f"\n{formatted_result}"
# 记录新产生的重要变量(简化版本)
vars_after = set(self.shell.user_ns.keys())
new_vars = vars_after - vars_before
# 只记录新创建的DataFrame等重要数据结构
important_new_vars = {}
for var_name in new_vars:
if not var_name.startswith("_"):
try:
var_value = self.shell.user_ns[var_name]
if hasattr(var_value, "shape"): # pandas DataFrame, numpy array
important_new_vars[var_name] = (
f"{type(var_value).__name__} with shape {var_value.shape}"
)
elif var_name in ["session_output_dir"]: # 重要的配置变量
important_new_vars[var_name] = str(var_value)
except Exception:
pass
return {
"success": True,
"output": output,
"error": "",
"variables": important_new_vars,
}
except Exception as e:
return {
"success": False,
"output": captured.stdout if "captured" in locals() else "",
"error": f"执行异常: {str(e)}\n{traceback.format_exc()}",
"variables": {},
}
def reset_environment(self):
"""重置执行环境"""
self.shell.reset()
self._setup_common_imports()
self._setup_chinese_font()
plt.close("all")
self.image_counter = 0
def set_variable(self, name: str, value: Any):
"""设置执行环境中的变量"""
self.shell.user_ns[name] = value
def get_environment_info(self) -> str:
"""获取当前执行环境的变量信息,用于系统提示词"""
info_parts = []
# 获取重要的数据变量
important_vars = {}
for var_name, var_value in self.shell.user_ns.items():
if not var_name.startswith("_") and var_name not in [
"In",
"Out",
"get_ipython",
"exit",
"quit",
]:
try:
if hasattr(var_value, "shape"): # pandas DataFrame, numpy array
important_vars[var_name] = (
f"{type(var_value).__name__} with shape {var_value.shape}"
)
elif var_name in ["session_output_dir"]: # 重要的路径变量
important_vars[var_name] = str(var_value)
elif (
isinstance(var_value, (int, float, str, bool))
and len(str(var_value)) < 100
):
important_vars[var_name] = (
f"{type(var_value).__name__}: {var_value}"
)
elif hasattr(var_value, "__module__") and var_value.__module__ in [
"pandas",
"numpy",
"matplotlib.pyplot",
]:
important_vars[var_name] = f"导入的模块: {var_value.__module__}"
except Exception:
continue
if important_vars:
info_parts.append("当前环境变量:")
for var_name, var_info in important_vars.items():
info_parts.append(f"- {var_name}: {var_info}")
else:
info_parts.append("当前环境已预装pandas, numpy, matplotlib等库")
# 添加输出目录信息
if "session_output_dir" in self.shell.user_ns:
info_parts.append(
f"图片保存目录: session_output_dir = '{self.shell.user_ns['session_output_dir']}'"
)
return "\n".join(info_parts)

View File

@@ -0,0 +1,15 @@
import os
from datetime import datetime
def create_session_output_dir(base_output_dir, user_input: str = "") -> str:
"""为本次分析创建独立的输出目录user_input 暂未参与命名,保留用于未来生成语义化目录名)"""
# 使用当前时间创建唯一的会话目录名格式YYYYMMDD_HHMMSS
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
session_id = timestamp
dir_name = f"session_{session_id}"
session_dir = os.path.join(base_output_dir, dir_name)
os.makedirs(session_dir, exist_ok=True)
return session_dir

90
utils/data_loader.py Normal file
View File

@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
import os
import pandas as pd
import io
def load_and_profile_data(file_paths: list) -> str:
"""
加载数据并生成数据画像
Args:
file_paths: 文件路径列表
Returns:
包含数据画像的Markdown字符串
"""
profile_summary = "# 数据画像报告 (Data Profile)\n\n"
if not file_paths:
return profile_summary + "未提供数据文件。"
for file_path in file_paths:
file_name = os.path.basename(file_path)
profile_summary += f"## 文件: {file_name}\n\n"
if not os.path.exists(file_path):
profile_summary += f"⚠️ 文件不存在: {file_path}\n\n"
continue
try:
# 根据扩展名选择加载方式
ext = os.path.splitext(file_path)[1].lower()
if ext == '.csv':
# 尝试多种编码
try:
df = pd.read_csv(file_path, encoding='utf-8')
except UnicodeDecodeError:
try:
df = pd.read_csv(file_path, encoding='gbk')
except Exception:
df = pd.read_csv(file_path, encoding='latin1')
elif ext in ['.xlsx', '.xls']:
df = pd.read_excel(file_path)
else:
profile_summary += f"⚠️ 不支持的文件格式: {ext}\n\n"
continue
# 基础信息
rows, cols = df.shape
profile_summary += f"- **维度**: {rows} 行 x {cols}\n"
profile_summary += f"- **列名**: `{', '.join(df.columns)}`\n\n"
profile_summary += "### 列详细分布:\n"
# 遍历分析每列
for col in df.columns:
dtype = df[col].dtype
null_count = df[col].isnull().sum()
null_ratio = (null_count / rows) * 100
profile_summary += f"#### {col} ({dtype})\n"
if null_count > 0:
profile_summary += f"- ⚠️ 空值: {null_count} ({null_ratio:.1f}%)\n"
# 数值列分析
if pd.api.types.is_numeric_dtype(dtype):
desc = df[col].describe()
profile_summary += f"- 统计: Min={desc['min']:.2f}, Max={desc['max']:.2f}, Mean={desc['mean']:.2f}\n"
# 文本/分类列分析
elif pd.api.types.is_object_dtype(dtype) or isinstance(dtype, pd.CategoricalDtype):
unique_count = df[col].nunique()
profile_summary += f"- 唯一值数量: {unique_count}\n"
# 展示Top 5高频值分布这对识别“高频问题”至关重要
if unique_count > 0:
top_n = df[col].value_counts().head(5)
top_items_str = ", ".join([f"{k}({v})" for k, v in top_n.items()])
profile_summary += f"- **TOP 5 高频值**: {top_items_str}\n"
# 时间列分析
elif pd.api.types.is_datetime64_any_dtype(dtype):
profile_summary += f"- 范围: {df[col].min()}{df[col].max()}\n"
profile_summary += "\n"
except Exception as e:
profile_summary += f"❌ 读取或分析文件失败: {str(e)}\n\n"
return profile_summary

219
utils/execution_session_client.py Normal file
View File

@@ -0,0 +1,219 @@
# -*- coding: utf-8 -*-
"""
Client for a per-analysis execution worker subprocess.
"""
import json
import os
import queue
import subprocess
import sys
import threading
import uuid
from typing import Any, Dict, Optional
class WorkerSessionError(RuntimeError):
"""Raised when the execution worker cannot serve a request."""
class WorkerTimeoutError(WorkerSessionError):
"""Raised when the worker does not respond within the configured timeout."""
class ExecutionSessionClient:
"""Client that proxies CodeExecutor methods to a dedicated worker process."""
def __init__(
self,
output_dir: str,
allowed_files=None,
python_executable: str = None,
request_timeout_seconds: float = 60.0,
startup_timeout_seconds: float = 180.0,
):
self.output_dir = os.path.abspath(output_dir)
self.allowed_files = [os.path.abspath(path) for path in (allowed_files or [])]
self.allowed_read_roots = sorted(
{os.path.dirname(path) for path in self.allowed_files}
)
self.python_executable = python_executable or sys.executable
self.request_timeout_seconds = request_timeout_seconds
self.startup_timeout_seconds = startup_timeout_seconds
self._process: Optional[subprocess.Popen] = None
self._stderr_handle = None
self._start_worker()
self._request(
"init_session",
{
"output_dir": self.output_dir,
"variables": {
"session_output_dir": self.output_dir,
"allowed_files": self.allowed_files,
"allowed_read_roots": self.allowed_read_roots,
},
},
timeout_seconds=self.startup_timeout_seconds,
)
def execute_code(self, code: str) -> Dict[str, Any]:
return self._request("execute_code", {"code": code})
def set_variable(self, name: str, value: Any) -> None:
self._request("set_variable", {"name": name, "value": value})
def get_environment_info(self) -> str:
payload = self._request("get_environment_info", {})
return payload.get("environment_info", "")
def reset_environment(self) -> None:
self._request("reset_environment", {})
self.set_variable("session_output_dir", self.output_dir)
self.set_variable("allowed_files", self.allowed_files)
self.set_variable("allowed_read_roots", self.allowed_read_roots)
def ping(self) -> bool:
payload = self._request("ping", {})
return bool(payload.get("alive"))
def close(self) -> None:
if self._process is None:
return
try:
if self._process.poll() is None:
self._request("shutdown", {}, timeout_seconds=5)
except Exception:
pass
finally:
self._teardown_worker()
def _start_worker(self) -> None:
runtime_dir = os.path.join(self.output_dir, ".worker_runtime")
mpl_dir = os.path.join(runtime_dir, "mplconfig")
ipython_dir = os.path.join(runtime_dir, "ipython")
os.makedirs(mpl_dir, exist_ok=True)
os.makedirs(ipython_dir, exist_ok=True)
stderr_log_path = os.path.join(self.output_dir, "execution_worker.log")
self._stderr_handle = open(stderr_log_path, "a", encoding="utf-8")
worker_script = os.path.join(
os.path.dirname(__file__),
"execution_worker.py",
)
env = os.environ.copy()
env["MPLCONFIGDIR"] = mpl_dir
env["IPYTHONDIR"] = ipython_dir
env.setdefault("PYTHONIOENCODING", "utf-8")
self._process = subprocess.Popen(
[self.python_executable, worker_script],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=self._stderr_handle,
text=True,
encoding="utf-8",
cwd=os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
env=env,
bufsize=1,
)
def _request(
self,
request_type: str,
payload: Dict[str, Any],
timeout_seconds: float = None,
) -> Dict[str, Any]:
if self._process is None or self._process.stdin is None or self._process.stdout is None:
raise WorkerSessionError("执行子进程未启动")
if self._process.poll() is not None:
raise WorkerSessionError(
f"执行子进程已退出,退出码: {self._process.returncode}"
)
request_id = str(uuid.uuid4())
message = {
"request_id": request_id,
"type": request_type,
"payload": payload,
}
try:
self._process.stdin.write(json.dumps(message, ensure_ascii=False) + "\n")
self._process.stdin.flush()
except BrokenPipeError as exc:
self._teardown_worker()
raise WorkerSessionError("执行子进程通信中断") from exc
effective_timeout = (
self.request_timeout_seconds
if timeout_seconds is None
else timeout_seconds
)
response_line = self._read_response_line(effective_timeout)
if not response_line:
if self._process.poll() is not None:
exit_code = self._process.returncode
self._teardown_worker()
raise WorkerSessionError(f"执行子进程已异常退出,退出码: {exit_code}")
raise WorkerSessionError("执行子进程未返回响应")
try:
response = json.loads(response_line)
except json.JSONDecodeError as exc:
raise WorkerSessionError(f"执行子进程返回了无效JSON: {response_line}") from exc
if response.get("request_id") != request_id:
raise WorkerSessionError("执行子进程响应 request_id 不匹配")
if response.get("status") != "ok":
raise WorkerSessionError(response.get("error", "执行子进程返回未知错误"))
return response.get("payload", {})
def _read_response_line(self, timeout_seconds: float) -> str:
assert self._process is not None and self._process.stdout is not None
response_queue: queue.Queue = queue.Queue(maxsize=1)
def _reader() -> None:
try:
response_queue.put((True, self._process.stdout.readline()))
except Exception as exc:
response_queue.put((False, exc))
thread = threading.Thread(target=_reader, daemon=True)
thread.start()
try:
success, value = response_queue.get(timeout=timeout_seconds)
except queue.Empty as exc:
self._teardown_worker(force=True)
raise WorkerTimeoutError(
f"执行子进程在 {timeout_seconds:.1f} 秒内未响应,已终止当前会话"
) from exc
if success:
return value
self._teardown_worker()
raise WorkerSessionError(f"读取执行子进程响应失败: {value}")
def _teardown_worker(self, force: bool = False) -> None:
if self._process is not None and self._process.poll() is None:
self._process.terminate()
try:
self._process.wait(timeout=5)
except subprocess.TimeoutExpired:
self._process.kill()
self._process.wait(timeout=5)
if self._stderr_handle is not None:
self._stderr_handle.close()
self._stderr_handle = None
self._process = None
def __del__(self):
self.close()

321
utils/execution_worker.py Normal file
View File

@@ -0,0 +1,321 @@
# -*- coding: utf-8 -*-
"""
Subprocess worker that hosts a single CodeExecutor instance for one analysis session.
"""
import json
import os
import sys
import traceback
from pathlib import Path
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from typing import Any, Dict, Iterable, Optional
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
class WorkerProtocolError(RuntimeError):
"""Raised when the worker receives an invalid protocol message."""
class FileAccessPolicy:
"""Controls which files the worker may read and where it may write outputs."""
def __init__(self):
self.allowed_reads = set()
self.allowed_read_roots = set()
self.allowed_write_root = ""
@staticmethod
def _normalize(path: Any) -> str:
if isinstance(path, Path):
path = str(path)
elif hasattr(path, "__fspath__"):
path = os.fspath(path)
elif not isinstance(path, str):
raise TypeError(f"不支持的路径类型: {type(path).__name__}")
return os.path.realpath(os.path.abspath(path))
def configure(
self,
allowed_reads: Iterable[Any],
allowed_write_root: Any,
allowed_read_roots: Optional[Iterable[Any]] = None,
) -> None:
self.allowed_reads = {
self._normalize(path) for path in allowed_reads if path
}
explicit_roots = {
self._normalize(path) for path in (allowed_read_roots or []) if path
}
derived_roots = {
os.path.dirname(path) for path in self.allowed_reads
}
self.allowed_read_roots = explicit_roots | derived_roots
self.allowed_write_root = (
self._normalize(allowed_write_root) if allowed_write_root else ""
)
def ensure_readable(self, path: Any) -> str:
normalized_path = self._normalize(path)
if normalized_path in self.allowed_reads:
return normalized_path
if self._is_within_read_roots(normalized_path):
return normalized_path
if self._is_within_write_root(normalized_path):
return normalized_path
raise PermissionError(f"禁止读取未授权文件: {normalized_path}")
def ensure_writable(self, path: Any) -> str:
normalized_path = self._normalize(path)
if self._is_within_write_root(normalized_path):
return normalized_path
raise PermissionError(f"禁止写入会话目录之外的路径: {normalized_path}")
def _is_within_write_root(self, normalized_path: str) -> bool:
if not self.allowed_write_root:
return False
return normalized_path == self.allowed_write_root or normalized_path.startswith(
self.allowed_write_root + os.sep
)
def _is_within_read_roots(self, normalized_path: str) -> bool:
for root in self.allowed_read_roots:
if normalized_path == root or normalized_path.startswith(root + os.sep):
return True
return False
def _write_message(message: Dict[str, Any]) -> None:
sys.stdout.write(json.dumps(message, ensure_ascii=False) + "\n")
sys.stdout.flush()
def _write_log(text: str) -> None:
if not text:
return
sys.stderr.write(text)
if not text.endswith("\n"):
sys.stderr.write("\n")
sys.stderr.flush()
class ExecutionWorker:
"""JSON-line protocol wrapper around CodeExecutor."""
def __init__(self):
self.executor = None
self.access_policy = FileAccessPolicy()
self._patches_installed = False
def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
request_id = request.get("request_id", "")
request_type = request.get("type")
payload = request.get("payload", {})
try:
if request_type == "ping":
return self._ok(request_id, {"alive": True})
if request_type == "init_session":
return self._handle_init_session(request_id, payload)
if request_type == "execute_code":
self._require_executor()
return self._ok(
request_id,
self.executor.execute_code(payload.get("code", "")),
)
if request_type == "set_variable":
self._require_executor()
self._handle_set_variable(payload["name"], payload["value"])
return self._ok(request_id, {"set": True})
if request_type == "get_environment_info":
self._require_executor()
return self._ok(
request_id,
{"environment_info": self.executor.get_environment_info()},
)
if request_type == "reset_environment":
self._require_executor()
self.executor.reset_environment()
return self._ok(request_id, {"reset": True})
if request_type == "shutdown":
return self._ok(request_id, {"shutdown": True})
raise WorkerProtocolError(f"未知请求类型: {request_type}")
except Exception as exc:
return {
"request_id": request_id,
"status": "error",
"error": str(exc),
"traceback": traceback.format_exc(),
}
def _handle_init_session(
self, request_id: str, payload: Dict[str, Any]
) -> Dict[str, Any]:
output_dir = payload.get("output_dir")
if not output_dir:
raise WorkerProtocolError("init_session 缺少 output_dir")
from utils.code_executor import CodeExecutor
self.executor = CodeExecutor(output_dir)
self.access_policy.configure(
payload.get("variables", {}).get("allowed_files", []),
output_dir,
payload.get("variables", {}).get("allowed_read_roots", []),
)
self._install_file_guards()
for name, value in payload.get("variables", {}).items():
self.executor.set_variable(name, value)
return self._ok(request_id, {"initialized": True})
def _install_file_guards(self) -> None:
if self._patches_installed:
return
import builtins
import matplotlib.figure
import matplotlib.pyplot as plt
import pandas as pd
policy = self.access_policy
original_open = builtins.open
original_read_csv = pd.read_csv
original_read_excel = pd.read_excel
original_to_csv = pd.DataFrame.to_csv
original_to_excel = pd.DataFrame.to_excel
original_plt_savefig = plt.savefig
original_figure_savefig = matplotlib.figure.Figure.savefig
def guarded_open(file, mode="r", *args, **kwargs):
if isinstance(file, (str, Path)) or hasattr(file, "__fspath__"):
if any(flag in mode for flag in ("w", "a", "x", "+")):
policy.ensure_writable(file)
else:
policy.ensure_readable(file)
return original_open(file, mode, *args, **kwargs)
def guarded_read_csv(filepath_or_buffer, *args, **kwargs):
if isinstance(filepath_or_buffer, (str, Path)) or hasattr(
filepath_or_buffer, "__fspath__"
):
policy.ensure_readable(filepath_or_buffer)
return original_read_csv(filepath_or_buffer, *args, **kwargs)
def guarded_read_excel(io, *args, **kwargs):
if isinstance(io, (str, Path)) or hasattr(io, "__fspath__"):
policy.ensure_readable(io)
return original_read_excel(io, *args, **kwargs)
def guarded_to_csv(df, path_or_buf=None, *args, **kwargs):
if isinstance(path_or_buf, (str, Path)) or hasattr(path_or_buf, "__fspath__"):
policy.ensure_writable(path_or_buf)
return original_to_csv(df, path_or_buf, *args, **kwargs)
def guarded_to_excel(df, excel_writer, *args, **kwargs):
if isinstance(excel_writer, (str, Path)) or hasattr(excel_writer, "__fspath__"):
policy.ensure_writable(excel_writer)
return original_to_excel(df, excel_writer, *args, **kwargs)
def guarded_savefig(*args, **kwargs):
target = args[0] if args else kwargs.get("fname")
if target is not None and (
isinstance(target, (str, Path)) or hasattr(target, "__fspath__")
):
policy.ensure_writable(target)
return original_plt_savefig(*args, **kwargs)
def guarded_figure_savefig(fig, fname, *args, **kwargs):
if isinstance(fname, (str, Path)) or hasattr(fname, "__fspath__"):
policy.ensure_writable(fname)
return original_figure_savefig(fig, fname, *args, **kwargs)
builtins.open = guarded_open
pd.read_csv = guarded_read_csv
pd.read_excel = guarded_read_excel
pd.DataFrame.to_csv = guarded_to_csv
pd.DataFrame.to_excel = guarded_to_excel
plt.savefig = guarded_savefig
matplotlib.figure.Figure.savefig = guarded_figure_savefig
self._patches_installed = True
def _require_executor(self) -> None:
if self.executor is None:
raise WorkerProtocolError("执行会话尚未初始化")
def _handle_set_variable(self, name: str, value: Any) -> None:
self.executor.set_variable(name, value)
if name == "allowed_files":
self.access_policy.configure(
value,
self.access_policy.allowed_write_root,
self.access_policy.allowed_read_roots,
)
elif name == "allowed_read_roots":
self.access_policy.configure(
self.access_policy.allowed_reads,
self.access_policy.allowed_write_root,
value,
)
elif name == "session_output_dir":
self.access_policy.configure(
self.access_policy.allowed_reads,
value,
self.access_policy.allowed_read_roots,
)
@staticmethod
def _ok(request_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
return {
"request_id": request_id,
"status": "ok",
"payload": payload,
}
def main() -> int:
worker = ExecutionWorker()
for raw_line in sys.stdin:
raw_line = raw_line.strip()
if not raw_line:
continue
try:
request = json.loads(raw_line)
except json.JSONDecodeError as exc:
_write_message(
{
"request_id": "",
"status": "error",
"error": f"无效JSON请求: {exc}",
}
)
continue
captured_stdout = StringIO()
captured_stderr = StringIO()
with redirect_stdout(captured_stdout), redirect_stderr(captured_stderr):
response = worker.handle_request(request)
_write_log(captured_stdout.getvalue())
_write_log(captured_stderr.getvalue())
_write_message(response)
if request.get("type") == "shutdown" and response.get("status") == "ok":
return 0
return 0
if __name__ == "__main__":
raise SystemExit(main())
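The worker above speaks a line-delimited JSON protocol: each stdin line is one request object, each stdout line one response. A minimal sketch of that framing, with illustrative helper names that are not part of the worker's API:

```python
import json
from io import StringIO

def frame_request(request: dict) -> str:
    # One JSON object per line -- the framing the worker's stdin loop expects
    return json.dumps(request, ensure_ascii=False) + "\n"

def read_messages(stream) -> list:
    # Mirror of the worker loop: skip blank lines, parse one JSON object per line
    messages = []
    for raw_line in stream:
        raw_line = raw_line.strip()
        if not raw_line:
            continue
        messages.append(json.loads(raw_line))
    return messages

buf = StringIO(frame_request({"request_id": "1", "type": "shutdown"}))
print(read_messages(buf))
```

A parent process would write `frame_request(...)` to the worker's stdin and parse each stdout line the same way `read_messages` does.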

38
utils/extract_code.py Normal file

@@ -0,0 +1,38 @@
from typing import Optional
import yaml
def extract_code_from_response(response: str) -> Optional[str]:
"""从LLM响应中提取代码"""
    try:
        # 尝试解析YAML优先匹配```yaml代码块
        if '```yaml' in response:
            start = response.find('```yaml') + 7
            end = response.find('```', start)
            yaml_content = response[start:end].strip() if end != -1 else response[start:].strip()
        elif '```' in response:
            start = response.find('```') + 3
            end = response.find('```', start)
            yaml_content = response[start:end].strip() if end != -1 else response[start:].strip()
        else:
            yaml_content = response.strip()
        yaml_data = yaml.safe_load(yaml_content)
        # safe_load 可能返回 None 或非字典,先做类型检查
        if isinstance(yaml_data, dict) and 'code' in yaml_data:
            return yaml_data['code']
    except Exception:  # 原为裸 except改为只捕获 Exception避免吞掉 KeyboardInterrupt
        pass
# 如果YAML解析失败尝试提取```python代码块
if '```python' in response:
start = response.find('```python') + 9
end = response.find('```', start)
if end != -1:
return response[start:end].strip()
elif '```' in response:
start = response.find('```') + 3
end = response.find('```', start)
if end != -1:
return response[start:end].strip()
return None
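When YAML parsing fails, the function falls back to plain ```python fences. That branch can be exercised on its own; a stdlib-only restatement of it (no PyYAML dependency), useful for quick testing:

```python
def extract_python_block(response: str):
    # Minimal equivalent of the ```python fallback branch above
    marker = '```python'
    if marker in response:
        start = response.find(marker) + len(marker)
        end = response.find('```', start)
        if end != -1:
            return response[start:end].strip()
    return None

sample = "Here is the code:\n```python\nprint('hi')\n```\ndone"
print(extract_python_block(sample))  # print('hi')
```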

230
utils/fallback_openai_client.py Normal file

@@ -0,0 +1,230 @@
# -*- coding: utf-8 -*-
import asyncio
from typing import Optional, Any, Mapping, Dict
from openai import AsyncOpenAI, APIStatusError, APIConnectionError, APITimeoutError, APIError
from openai.types.chat import ChatCompletion
class AsyncFallbackOpenAIClient:
"""
一个支持备用 API 自动切换的异步 OpenAI 客户端。
当主 API 调用因特定错误(如内容过滤)失败时,会自动尝试使用备用 API。
"""
def __init__(
self,
primary_api_key: str,
primary_base_url: str,
primary_model_name: str,
fallback_api_key: Optional[str] = None,
fallback_base_url: Optional[str] = None,
fallback_model_name: Optional[str] = None,
primary_client_args: Optional[Dict[str, Any]] = None,
fallback_client_args: Optional[Dict[str, Any]] = None,
content_filter_error_code: str = "1301", # 特定于 Zhipu 的内容过滤错误代码
content_filter_error_field: str = "contentFilter", # 特定于 Zhipu 的内容过滤错误字段
max_retries_primary: int = 1, # 主API重试次数
max_retries_fallback: int = 1, # 备用API重试次数
retry_delay_seconds: float = 1.0 # 重试延迟时间
):
"""
初始化 AsyncFallbackOpenAIClient。
Args:
primary_api_key: 主 API 的密钥。
primary_base_url: 主 API 的基础 URL。
primary_model_name: 主 API 使用的模型名称。
fallback_api_key: 备用 API 的密钥 (可选)。
fallback_base_url: 备用 API 的基础 URL (可选)。
fallback_model_name: 备用 API 使用的模型名称 (可选)。
primary_client_args: 传递给主 AsyncOpenAI 客户端的其他参数。
fallback_client_args: 传递给备用 AsyncOpenAI 客户端的其他参数。
content_filter_error_code: 触发回退的内容过滤错误的特定错误代码。
content_filter_error_field: 触发回退的内容过滤错误中存在的字段名。
max_retries_primary: 主 API 失败时的最大重试次数。
max_retries_fallback: 备用 API 失败时的最大重试次数。
retry_delay_seconds: 重试前的延迟时间(秒)。
"""
if not primary_api_key or not primary_base_url:
raise ValueError("主 API 密钥和基础 URL 不能为空。")
_primary_args = primary_client_args or {}
self.primary_client = AsyncOpenAI(api_key=primary_api_key, base_url=primary_base_url, **_primary_args)
self.primary_model_name = primary_model_name
self.fallback_client: Optional[AsyncOpenAI] = None
self.fallback_model_name: Optional[str] = None
if fallback_api_key and fallback_base_url and fallback_model_name:
_fallback_args = fallback_client_args or {}
self.fallback_client = AsyncOpenAI(api_key=fallback_api_key, base_url=fallback_base_url, **_fallback_args)
self.fallback_model_name = fallback_model_name
else:
print("⚠️ 警告: 未完全配置备用 API 客户端。如果主 API 失败,将无法进行回退。")
self.content_filter_error_code = content_filter_error_code
self.content_filter_error_field = content_filter_error_field
self.max_retries_primary = max_retries_primary
self.max_retries_fallback = max_retries_fallback
self.retry_delay_seconds = retry_delay_seconds
self._closed = False
async def _attempt_api_call(
self,
client: AsyncOpenAI,
model_name: str,
messages: list[Mapping[str, Any]],
max_retries: int,
api_name: str,
**kwargs: Any
) -> ChatCompletion:
"""
尝试调用指定的 OpenAI API 客户端,并进行重试。
"""
        last_exception = None
        # 先取出可能的 model 覆盖:若在循环内 pop第二次重试时覆盖值会丢失
        model_override = kwargs.pop('model', model_name)
        for attempt in range(max_retries + 1):
            try:
                # print(f"尝试使用 {api_name} API ({client.base_url}) 模型: {model_override}, 第 {attempt + 1} 次尝试")
                completion = await client.chat.completions.create(
                    model=model_override,
                    messages=messages,
                    **kwargs
                )
                return completion
except (APIConnectionError, APITimeoutError) as e: # 通常可以重试的网络错误
last_exception = e
print(f"⚠️ {api_name} API 调用时发生可重试错误 ({type(e).__name__}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
if attempt < max_retries:
await asyncio.sleep(self.retry_delay_seconds * (attempt + 1)) # 增加延迟
else:
print(f"{api_name} API 在达到最大重试次数后仍然失败。")
except APIStatusError as e: # API 返回的特定状态码错误
is_content_filter_error = False
if e.status_code == 400:
try:
error_json = e.response.json()
error_details = error_json.get("error", {})
if (error_details.get("code") == self.content_filter_error_code and
self.content_filter_error_field in error_json):
is_content_filter_error = True
except Exception:
pass # 解析错误响应失败,不认为是内容过滤错误
                if is_content_filter_error and api_name == "主":  # 如果是主 API 的内容过滤错误,则直接抛出以便回退
raise e
last_exception = e
print(f"⚠️ {api_name} API 调用时发生 APIStatusError ({e.status_code}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
if attempt < max_retries:
await asyncio.sleep(self.retry_delay_seconds * (attempt + 1))
else:
print(f"{api_name} API 在达到最大重试次数后仍然失败 (APIStatusError)。")
except APIError as e: # 其他不可轻易重试的 OpenAI 错误
last_exception = e
print(f"{api_name} API 调用时发生不可重试错误 ({type(e).__name__}): {e}")
break # 不再重试此类错误
if last_exception:
raise last_exception
raise RuntimeError(f"{api_name} API 调用意外失败。") # 理论上不应到达这里
async def chat_completions_create(
self,
messages: list[Mapping[str, Any]],
**kwargs: Any # 用于传递其他 OpenAI 参数,如 max_tokens, temperature 等。
) -> ChatCompletion:
"""
使用主 API 创建聊天补全,如果发生特定内容过滤错误或主 API 调用失败,则回退到备用 API。
支持对主 API 和备用 API 的可重试错误进行重试。
Args:
messages: OpenAI API 的消息列表。
**kwargs: 传递给 OpenAI API 调用的其他参数。
Returns:
ChatCompletion 对象。
Raises:
APIError: 如果主 API 和备用 API (如果尝试) 都返回 API 错误。
RuntimeError: 如果客户端已关闭。
"""
if self._closed:
raise RuntimeError("客户端已关闭。")
try:
completion = await self._attempt_api_call(
client=self.primary_client,
model_name=self.primary_model_name,
messages=messages,
max_retries=self.max_retries_primary,
                api_name="主",
**kwargs.copy()
)
return completion
except APIStatusError as e_primary:
is_content_filter_error = False
if e_primary.status_code == 400:
try:
error_json = e_primary.response.json()
error_details = error_json.get("error", {})
if (error_details.get("code") == self.content_filter_error_code and
self.content_filter_error_field in error_json):
is_content_filter_error = True
except Exception:
pass
if is_content_filter_error and self.fallback_client and self.fallback_model_name:
                print(f"⚠️ 主 API 内容过滤错误 ({e_primary.status_code})。尝试切换到备用 API ({self.fallback_client.base_url})...")
try:
fallback_completion = await self._attempt_api_call(
client=self.fallback_client,
model_name=self.fallback_model_name,
messages=messages,
max_retries=self.max_retries_fallback,
api_name="备用",
**kwargs.copy()
)
print(f"✅ 备用 API 调用成功。")
return fallback_completion
except APIError as e_fallback:
print(f"❌ 备用 API 调用最终失败: {type(e_fallback).__name__} - {e_fallback}")
raise e_fallback
else:
if not (self.fallback_client and self.fallback_model_name and is_content_filter_error):
# 如果不是内容过滤错误或者没有可用的备用API则记录主API的原始错误
                    print(f"⚠️ 主 API 错误 ({type(e_primary).__name__}: {e_primary}), 且不满足备用条件或备用API未配置。")
raise e_primary
except APIError as e_primary_other:
print(f"❌ 主 API 调用最终失败 (非内容过滤,错误类型: {type(e_primary_other).__name__}): {e_primary_other}")
if self.fallback_client and self.fallback_model_name:
                print(f"⚠️ 主 API 失败,尝试切换到备用 API ({self.fallback_client.base_url})...")
try:
fallback_completion = await self._attempt_api_call(
client=self.fallback_client,
model_name=self.fallback_model_name,
messages=messages,
max_retries=self.max_retries_fallback,
api_name="备用",
**kwargs.copy()
)
print(f"✅ 备用 API 调用成功。")
return fallback_completion
except APIError as e_fallback_after_primary_fail:
print(f"❌ 备用 API 在主 API 失败后也调用失败: {type(e_fallback_after_primary_fail).__name__} - {e_fallback_after_primary_fail}")
raise e_fallback_after_primary_fail
else:
raise e_primary_other
async def close(self):
"""异步关闭主客户端和备用客户端 (如果存在)。"""
if not self._closed:
await self.primary_client.close()
if self.fallback_client:
await self.fallback_client.close()
self._closed = True
# print("AsyncFallbackOpenAIClient 已关闭。")
async def __aenter__(self):
if self._closed:
raise RuntimeError("AsyncFallbackOpenAIClient 不能在关闭后重新进入。请创建一个新实例。")
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
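The content-filter check at the heart of the fallback decision is duplicated in two `except` blocks above. It can be factored into a small pure function; a sketch, assuming the Zhipu-style error body implied by the class defaults:

```python
def is_content_filter_error(error_json: dict,
                            code: str = "1301",
                            field: str = "contentFilter") -> bool:
    # Same rule as the two inline checks above: the nested error code matches
    # AND the content-filter field is present at the top level of the body.
    details = error_json.get("error", {})
    return details.get("code") == code and field in error_json

blocked = {"error": {"code": "1301"}, "contentFilter": [{"level": 2}]}
print(is_content_filter_error(blocked))  # True
```

Calling this from both `except APIStatusError` branches would keep the primary/fallback policy in one place.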


@@ -0,0 +1,25 @@
from typing import Any, Dict
def format_execution_result(result: Dict[str, Any]) -> str:
"""格式化执行结果为用户可读的反馈"""
feedback = []
    if result.get('success'):
        feedback.append("✅ 代码执行成功")
        if result.get('output'):
            feedback.append(f"📊 输出结果:\n{result['output']}")
        if result.get('variables'):
            feedback.append("📋 新生成的变量:")
            for var_name, var_info in result['variables'].items():
                feedback.append(f"  - {var_name}: {var_info}")
    else:
        feedback.append("❌ 代码执行失败")
        feedback.append(f"错误信息: {result.get('error', '未知错误')}")
        if result.get('output'):
            feedback.append(f"部分输出: {result['output']}")
return "\n".join(feedback)
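For illustration, the formatter expects a result dict of this shape; a condensed, self-contained restatement of the success branch (the sample values below are made up):

```python
def format_success(result: dict) -> str:
    # Condensed copy of the success branch; `result` has the same keys
    # (success/output/variables) as format_execution_result expects.
    feedback = ["✅ 代码执行成功"]
    if result.get('output'):
        feedback.append(f"📊 输出结果:\n{result['output']}")
    if result.get('variables'):
        feedback.append("📋 新生成的变量:")
        for var_name, var_info in result['variables'].items():
            feedback.append(f"  - {var_name}: {var_info}")
    return "\n".join(feedback)

result = {"success": True, "output": "rows=120", "variables": {"df": "DataFrame(120x5)"}}
print(format_success(result))
```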

91
utils/llm_helper.py Normal file

@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
"""
LLM调用辅助模块
"""
import asyncio
import yaml
from config.llm_config import LLMConfig
from utils.fallback_openai_client import AsyncFallbackOpenAIClient
class LLMCallError(RuntimeError):
"""Raised when the configured LLM backend cannot complete a request."""
class LLMHelper:
"""LLM调用辅助类支持同步和异步调用"""
def __init__(self, config: LLMConfig = None):
self.config = config or LLMConfig()
self.config.validate()
self.client = AsyncFallbackOpenAIClient(
primary_api_key=self.config.api_key,
primary_base_url=self.config.base_url,
primary_model_name=self.config.model
)
async def async_call(self, prompt: str, system_prompt: str = None, max_tokens: int = None, temperature: float = None) -> str:
"""异步调用LLM"""
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
kwargs = {}
if max_tokens is not None:
kwargs['max_tokens'] = max_tokens
else:
kwargs['max_tokens'] = self.config.max_tokens
if temperature is not None:
kwargs['temperature'] = temperature
else:
kwargs['temperature'] = self.config.temperature
try:
response = await self.client.chat_completions_create(
messages=messages,
**kwargs
)
return response.choices[0].message.content
except Exception as e:
raise LLMCallError(f"LLM调用失败: {e}") from e
    def call(self, prompt: str, system_prompt: str = None, max_tokens: int = None, temperature: float = None) -> str:
        """同步调用LLM"""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # 当前线程没有运行中的事件循环:普通同步脚本路径
            return asyncio.run(self.async_call(prompt, system_prompt, max_tokens, temperature))
        # 已有事件循环(如 Jupyter需要 nest_asyncio 才能嵌套 run_until_complete
        import nest_asyncio
        nest_asyncio.apply()
        return loop.run_until_complete(self.async_call(prompt, system_prompt, max_tokens, temperature))
def parse_yaml_response(self, response: str) -> dict:
"""解析YAML格式的响应"""
        try:
            # 提取```yaml和```之间的内容
            if '```yaml' in response:
                start = response.find('```yaml') + 7
                end = response.find('```', start)
                yaml_content = response[start:end].strip() if end != -1 else response[start:].strip()
            elif '```' in response:
                start = response.find('```') + 3
                end = response.find('```', start)
                yaml_content = response[start:end].strip() if end != -1 else response[start:].strip()
            else:
                yaml_content = response.strip()
            data = yaml.safe_load(yaml_content)
            # safe_load 可能返回 None 或标量,统一只返回字典
            return data if isinstance(data, dict) else {}
except Exception as e:
print(f"YAML解析失败: {e}")
print(f"原始响应: {response}")
return {}
async def close(self):
"""关闭客户端"""
await self.client.close()
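`call()` drives the async client from synchronous code. A stdlib-only sketch of that sync-over-async bridge, where the `fetch` coroutine stands in for `async_call` (`nest_asyncio` is only needed when a loop is already running, e.g. in Jupyter):

```python
import asyncio

async def fetch() -> str:
    # Placeholder coroutine standing in for LLMHelper.async_call
    await asyncio.sleep(0)
    return "ok"

def call() -> str:
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread: the plain-script path
        return asyncio.run(fetch())
    # Inside a running loop you would need nest_asyncio.apply()
    # before run_until_complete; raising keeps this sketch stdlib-only.
    raise RuntimeError("use `await fetch()` inside a running event loop")

print(call())  # ok
```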

4
webapp/__init__.py Normal file

@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
"""
Web application package for the data analysis platform.
"""

242
webapp/api.py Normal file

@@ -0,0 +1,242 @@
# -*- coding: utf-8 -*-
"""
FastAPI application for the data analysis platform.
"""
import os
import re
import uuid
from typing import List, Optional
from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from webapp.session_manager import SessionManager
from webapp.storage import Storage, utcnow_iso
from webapp.task_runner import TaskRunner
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RUNTIME_DIR = os.path.join(BASE_DIR, "runtime")
UPLOADS_DIR = os.path.join(RUNTIME_DIR, "uploads")
OUTPUTS_DIR = os.path.join(BASE_DIR, "outputs")
DB_PATH = os.path.join(RUNTIME_DIR, "analysis_platform.db")
os.makedirs(UPLOADS_DIR, exist_ok=True)
os.makedirs(OUTPUTS_DIR, exist_ok=True)
storage = Storage(DB_PATH)
session_manager = SessionManager(OUTPUTS_DIR)
task_runner = TaskRunner(
storage=storage,
uploads_dir=UPLOADS_DIR,
outputs_dir=OUTPUTS_DIR,
session_manager=session_manager,
max_workers=2,
)
app = FastAPI(title="Data Analysis Platform API")
STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
class CreateSessionRequest(BaseModel):
user_id: str = Field(..., min_length=1)
title: str = Field(..., min_length=1)
query: str = Field(..., min_length=1)
file_ids: List[str]
template_file_id: Optional[str] = None
class CreateTopicRequest(BaseModel):
user_id: str = Field(..., min_length=1)
query: str = Field(..., min_length=1)
def sanitize_filename(filename: str) -> str:
cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", filename).strip("._")
return cleaned or "upload.bin"
def ensure_session_access(session_id: str, user_id: str) -> dict:
session = storage.get_session(session_id, user_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
return session
def ensure_task_access(task_id: str, user_id: str) -> dict:
task = storage.get_task(task_id, user_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task
@app.get("/health")
def health():
return {"status": "ok"}
@app.get("/", response_class=HTMLResponse)
def index():
index_path = os.path.join(STATIC_DIR, "index.html")
with open(index_path, "r", encoding="utf-8") as f:
return HTMLResponse(f.read())
@app.post("/files/upload")
async def upload_files(
user_id: str = Form(...),
files: List[UploadFile] = File(...),
):
saved = []
user_dir = os.path.join(UPLOADS_DIR, user_id)
os.makedirs(user_dir, exist_ok=True)
for upload in files:
safe_name = sanitize_filename(upload.filename or "upload.bin")
stored_path = os.path.join(user_dir, f"{uuid.uuid4()}_{safe_name}")
with open(stored_path, "wb") as f:
while True:
chunk = await upload.read(1024 * 1024)
if not chunk:
break
f.write(chunk)
saved.append(storage.create_uploaded_file(user_id, upload.filename or safe_name, stored_path))
return {"files": saved}
@app.get("/files")
def list_files(user_id: str = Query(...)):
return {"files": storage.list_all_uploaded_files(user_id)}
@app.post("/sessions")
def create_session(request: CreateSessionRequest):
if not storage.list_uploaded_files(request.file_ids, request.user_id):
raise HTTPException(status_code=400, detail="No valid files found for session")
session = storage.create_session(
user_id=request.user_id,
title=request.title,
uploaded_file_ids=request.file_ids,
template_file_id=request.template_file_id,
)
task = storage.create_task(
session_id=session["id"],
user_id=request.user_id,
query=request.query,
uploaded_file_ids=request.file_ids,
template_file_id=request.template_file_id,
)
task_runner.submit(task["id"], request.user_id)
return {"session": session, "task": task}
@app.get("/sessions")
def list_sessions(user_id: str = Query(...)):
return {"sessions": storage.list_sessions(user_id)}
@app.get("/sessions/{session_id}")
def get_session(session_id: str, user_id: str = Query(...)):
session = ensure_session_access(session_id, user_id)
tasks = storage.list_session_tasks(session_id, user_id)
return {"session": session, "tasks": tasks}
@app.post("/sessions/{session_id}/topics")
def create_followup_topic(session_id: str, request: CreateTopicRequest):
session = ensure_session_access(session_id, request.user_id)
if session["status"] == "closed":
raise HTTPException(status_code=400, detail="Session is closed")
task = storage.create_task(
session_id=session_id,
user_id=request.user_id,
query=request.query,
uploaded_file_ids=session["uploaded_file_ids"],
template_file_id=session.get("template_file_id"),
)
task_runner.submit(task["id"], request.user_id)
return {"session": session, "task": task}
@app.post("/sessions/{session_id}/close")
def close_session(session_id: str, user_id: str = Query(...)):
session = ensure_session_access(session_id, user_id)
storage.update_session(session_id, status="closed", closed_at=utcnow_iso())
session_manager.close(session_id)
return {"session": storage.get_session(session_id, user_id)}
@app.get("/tasks")
def list_tasks(user_id: str = Query(...)):
return {"tasks": storage.list_tasks(user_id)}
@app.get("/tasks/{task_id}")
def get_task(task_id: str, user_id: str = Query(...)):
task = ensure_task_access(task_id, user_id)
return {"task": task}
@app.get("/tasks/{task_id}/report")
def get_task_report(task_id: str, user_id: str = Query(...)):
task = ensure_task_access(task_id, user_id)
report_path = task.get("report_file_path")
if not report_path or not os.path.exists(report_path):
raise HTTPException(status_code=404, detail="Report not available")
return FileResponse(report_path, media_type="text/markdown", filename=os.path.basename(report_path))
@app.get("/tasks/{task_id}/report/content")
def get_task_report_content(task_id: str, user_id: str = Query(...)):
task = ensure_task_access(task_id, user_id)
report_path = task.get("report_file_path")
if not report_path or not os.path.exists(report_path):
raise HTTPException(status_code=404, detail="Report not available")
with open(report_path, "r", encoding="utf-8") as f:
return {"content": f.read(), "filename": os.path.basename(report_path)}
@app.get("/tasks/{task_id}/artifacts")
def list_task_artifacts(task_id: str, user_id: str = Query(...)):
task = ensure_task_access(task_id, user_id)
session_output_dir = task.get("session_output_dir")
if not session_output_dir or not os.path.isdir(session_output_dir):
return {"artifacts": []}
artifacts = []
for name in sorted(os.listdir(session_output_dir)):
path = os.path.join(session_output_dir, name)
if not os.path.isfile(path):
continue
artifacts.append(
{
"name": name,
"size": os.path.getsize(path),
"is_image": name.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".webp")),
"url": f"/tasks/{task_id}/artifacts/{name}?user_id={user_id}",
}
)
return {"artifacts": artifacts}
@app.get("/tasks/{task_id}/artifacts/{artifact_name}")
def get_artifact(task_id: str, artifact_name: str, user_id: str = Query(...)):
task = ensure_task_access(task_id, user_id)
session_output_dir = task.get("session_output_dir")
if not session_output_dir:
raise HTTPException(status_code=404, detail="Artifact directory not available")
artifact_path = os.path.realpath(os.path.join(session_output_dir, artifact_name))
session_root = os.path.realpath(session_output_dir)
if artifact_path != session_root and not artifact_path.startswith(session_root + os.sep):
raise HTTPException(status_code=400, detail="Invalid artifact path")
if not os.path.exists(artifact_path):
raise HTTPException(status_code=404, detail="Artifact not found")
return FileResponse(artifact_path, filename=os.path.basename(artifact_path))
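`get_artifact` rejects paths that escape the session directory by comparing `realpath` prefixes. The same containment rule, isolated for testing (the paths below are illustrative):

```python
import os

def is_within(root: str, candidate: str) -> bool:
    # Same check as get_artifact: resolve both paths, then require the
    # candidate to be the root itself or live strictly under it.
    root_real = os.path.realpath(root)
    cand_real = os.path.realpath(os.path.join(root, candidate))
    return cand_real == root_real or cand_real.startswith(root_real + os.sep)

print(is_within("/tmp/outputs", "chart.png"))       # True
print(is_within("/tmp/outputs", "../etc/passwd"))   # False
```

The `+ os.sep` suffix matters: without it, a sibling directory like `/tmp/outputs_evil` would pass the prefix test.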

66
webapp/session_manager.py Normal file

@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
"""
In-memory registry for long-lived analysis sessions.
"""
import os
import threading
from dataclasses import dataclass, field
from typing import Dict, Optional
from config.llm_config import LLMConfig
from data_analysis_agent import DataAnalysisAgent
@dataclass
class RuntimeSession:
session_id: str
user_id: str
session_output_dir: str
uploaded_files: list[str]
template_path: Optional[str]
agent: DataAnalysisAgent
initialized: bool = False
lock: threading.Lock = field(default_factory=threading.Lock)
class SessionManager:
"""Keeps session-scoped agents alive across follow-up topics."""
def __init__(self, outputs_dir: str):
self.outputs_dir = os.path.abspath(outputs_dir)
self._sessions: Dict[str, RuntimeSession] = {}
self._lock = threading.Lock()
def get_or_create(
self,
session_id: str,
user_id: str,
session_output_dir: str,
uploaded_files: list[str],
template_path: Optional[str],
) -> RuntimeSession:
with self._lock:
runtime = self._sessions.get(session_id)
if runtime is None:
runtime = RuntimeSession(
session_id=session_id,
user_id=user_id,
session_output_dir=session_output_dir,
uploaded_files=uploaded_files,
template_path=template_path,
agent=DataAnalysisAgent(
llm_config=LLMConfig(),
output_dir=self.outputs_dir,
max_rounds=20,
force_max_rounds=False,
),
)
self._sessions[session_id] = runtime
return runtime
def close(self, session_id: str) -> None:
with self._lock:
runtime = self._sessions.pop(session_id, None)
if runtime:
runtime.agent.close_session()
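`SessionManager.get_or_create` is a lock-protected keyed-singleton registry. The same pattern in miniature, where the `make` factory stands in for the `DataAnalysisAgent` construction above:

```python
import threading

class Registry:
    def __init__(self, make):
        self._make = make          # factory, analogous to DataAnalysisAgent(...)
        self._items = {}
        self._lock = threading.Lock()

    def get_or_create(self, key):
        # Check-and-insert under one lock so concurrent callers
        # for the same key always receive the identical object.
        with self._lock:
            item = self._items.get(key)
            if item is None:
                item = self._make(key)
                self._items[key] = item
            return item

    def close(self, key):
        with self._lock:
            return self._items.pop(key, None)

reg = Registry(lambda key: {"session_id": key})
a = reg.get_or_create("s1")
b = reg.get_or_create("s1")
print(a is b)  # True
```

Handing every follow-up topic the same live object per session is what preserves the agent's in-memory analysis context between topics.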

299
webapp/static/app.js Normal file

@@ -0,0 +1,299 @@
const state = {
userId: null,
files: [],
sessions: [],
currentSessionId: null,
currentTaskId: null,
pollTimer: null,
};
function ensureUserId() {
const key = "vibe_data_ana_user_id";
let userId = localStorage.getItem(key);
if (!userId) {
userId = `guest_${crypto.randomUUID()}`;
localStorage.setItem(key, userId);
}
state.userId = userId;
document.getElementById("user-id").textContent = userId;
}
async function api(path, options = {}) {
const response = await fetch(path, options);
const text = await response.text();
let data = {};
try {
data = text ? JSON.parse(text) : {};
} catch {
data = { raw: text };
}
if (!response.ok) {
throw new Error(data.detail || data.error || response.statusText);
}
return data;
}
function setText(id, value) {
document.getElementById(id).textContent = value || "";
}
function renderFiles() {
  const fileList = document.getElementById("file-list");
  const picker = document.getElementById("session-file-picker");
  fileList.innerHTML = "";
  picker.innerHTML = "";
  if (!state.files.length) {
    fileList.innerHTML = '<div class="empty">还没有上传文件。</div>';
    picker.innerHTML = '<div class="empty">先上传文件后才能创建会话。</div>';
    return;
  }
  state.files.forEach((file) => {
    // 文件名是用户可控内容,用 textContent 写入,避免 HTML 注入
    const item = document.createElement("div");
    item.className = "file-item";
    const name = document.createElement("strong");
    name.textContent = file.original_name;
    const hint = document.createElement("div");
    hint.className = "hint";
    hint.textContent = file.id;
    item.append(name, hint);
    fileList.appendChild(item);
    const label = document.createElement("label");
    label.className = "checkbox-item";
    const checkbox = document.createElement("input");
    checkbox.type = "checkbox";
    checkbox.value = file.id;
    const span = document.createElement("span");
    span.textContent = file.original_name;
    label.append(checkbox, span);
    picker.appendChild(label);
  });
}
function statusBadge(status) {
return `<span class="status ${status}">${status}</span>`;
}
function renderSessions() {
  const container = document.getElementById("session-list");
  container.innerHTML = "";
  if (!state.sessions.length) {
    container.innerHTML = '<div class="empty">暂无会话。</div>';
    return;
  }
  state.sessions.forEach((session) => {
    const card = document.createElement("button");
    card.type = "button";
    card.className = `session-card ${session.id === state.currentSessionId ? "active" : ""}`;
    // 会话标题是用户输入,用 textContent 写入,避免 HTML 注入
    const title = document.createElement("div");
    const strong = document.createElement("strong");
    strong.textContent = session.title;
    title.appendChild(strong);
    const hint = document.createElement("div");
    hint.className = "hint";
    hint.textContent = session.id;
    const status = document.createElement("div");
    status.innerHTML = statusBadge(session.status);
    card.append(title, hint, status);
    card.onclick = () => loadSessionDetail(session.id);
    container.appendChild(card);
  });
}
function renderTasks(tasks) {
  const container = document.getElementById("task-list");
  container.innerHTML = "";
  if (!tasks.length) {
    container.innerHTML = '<div class="empty">当前会话还没有专题任务。</div>';
    return;
  }
  tasks.forEach((task) => {
    const card = document.createElement("button");
    card.type = "button";
    card.className = `task-card ${task.id === state.currentTaskId ? "active" : ""}`;
    // 专题问题是用户输入,用 textContent 写入,避免 HTML 注入
    const query = document.createElement("div");
    const strong = document.createElement("strong");
    strong.textContent = task.query;
    query.appendChild(strong);
    const hint = document.createElement("div");
    hint.className = "hint";
    hint.textContent = task.created_at;
    const status = document.createElement("div");
    status.innerHTML = statusBadge(task.status);
    card.append(query, hint, status);
    card.onclick = () => loadTaskReport(task.id);
    container.appendChild(card);
  });
}
async function refreshFiles() {
const data = await api(`/files?user_id=${encodeURIComponent(state.userId)}`);
state.files = data.files || [];
renderFiles();
}
async function refreshSessions(selectSessionId = null) {
const data = await api(`/sessions?user_id=${encodeURIComponent(state.userId)}`);
state.sessions = data.sessions || [];
renderSessions();
if (selectSessionId) {
await loadSessionDetail(selectSessionId);
} else if (state.currentSessionId) {
const exists = state.sessions.some((session) => session.id === state.currentSessionId);
if (exists) {
await loadSessionDetail(state.currentSessionId, false);
}
}
}
async function loadSessionDetail(sessionId, renderReport = true) {
const data = await api(`/sessions/${sessionId}?user_id=${encodeURIComponent(state.userId)}`);
state.currentSessionId = sessionId;
document.getElementById("detail-title").textContent = data.session.title;
document.getElementById("detail-meta").textContent = `${data.session.id} · ${data.session.status}`;
renderSessions();
renderTasks(data.tasks || []);
const latestDoneTask = (data.tasks || []).slice().reverse().find((task) => task.status === "succeeded");
if (renderReport && latestDoneTask) {
await loadTaskReport(latestDoneTask.id);
} else if (!latestDoneTask) {
setText("report-title", "暂无已完成专题");
setText("report-content", "当前会话还没有可展示的报告。");
document.getElementById("artifact-gallery").innerHTML = "";
}
}
async function loadTaskReport(taskId) {
state.currentTaskId = taskId;
renderSessions();
const taskData = await api(`/tasks/${taskId}?user_id=${encodeURIComponent(state.userId)}`);
setText("report-title", taskData.task.query);
if (taskData.task.status !== "succeeded") {
setText("report-content", `当前任务状态为 ${taskData.task.status}\n错误信息:${taskData.task.error_message || "暂无"}`);
document.getElementById("artifact-gallery").innerHTML = "";
return;
}
const reportData = await api(`/tasks/${taskId}/report/content?user_id=${encodeURIComponent(state.userId)}`);
setText("report-content", reportData.content || "");
const artifactData = await api(`/tasks/${taskId}/artifacts?user_id=${encodeURIComponent(state.userId)}`);
renderArtifacts(artifactData.artifacts || []);
}
function renderArtifacts(artifacts) {
const gallery = document.getElementById("artifact-gallery");
gallery.innerHTML = "";
const images = artifacts.filter((item) => item.is_image);
if (!images.length) {
gallery.innerHTML = '<div class="empty">当前任务没有图片产物。</div>';
return;
}
images.forEach((artifact) => {
const card = document.createElement("div");
card.className = "artifact-card";
card.innerHTML = `
<img src="${artifact.url}" alt="${artifact.name}" />
<div><a href="${artifact.url}" target="_blank" rel="noreferrer">${artifact.name}</a></div>
`;
gallery.appendChild(card);
});
}
async function handleUpload(event) {
event.preventDefault();
const input = document.getElementById("upload-input");
if (!input.files.length) {
setText("upload-status", "请选择至少一个文件。");
return;
}
const formData = new FormData();
formData.append("user_id", state.userId);
Array.from(input.files).forEach((file) => formData.append("files", file));
setText("upload-status", "上传中...");
await api("/files/upload", { method: "POST", body: formData });
setText("upload-status", "文件上传完成。");
input.value = "";
await refreshFiles();
}
async function handleCreateSession(event) {
event.preventDefault();
const checked = Array.from(document.querySelectorAll("#session-file-picker input:checked"));
const fileIds = checked.map((item) => item.value);
if (!fileIds.length) {
setText("session-status", "请至少选择一个文件。");
return;
}
const payload = {
user_id: state.userId,
title: document.getElementById("session-title").value.trim(),
query: document.getElementById("session-query").value.trim(),
file_ids: fileIds,
};
setText("session-status", "会话创建中...");
const data = await api("/sessions", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(payload),
});
setText("session-status", "会话已创建,正在执行首个专题。");
document.getElementById("session-form").reset();
await refreshSessions(data.session.id);
}
async function handleFollowup() {
if (!state.currentSessionId) {
setText("detail-meta", "请先选择一个会话。");
return;
}
const query = document.getElementById("followup-query").value.trim();
if (!query) {
return;
}
await api(`/sessions/${state.currentSessionId}/topics`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ user_id: state.userId, query }),
});
document.getElementById("followup-query").value = "";
await refreshSessions(state.currentSessionId);
}
async function handleCloseSession() {
if (!state.currentSessionId) {
return;
}
await api(`/sessions/${state.currentSessionId}/close?user_id=${encodeURIComponent(state.userId)}`, {
method: "POST",
});
await refreshSessions(state.currentSessionId);
}
function startPolling() {
if (state.pollTimer) {
clearInterval(state.pollTimer);
}
state.pollTimer = setInterval(() => {
refreshSessions().catch((error) => console.error(error));
}, 8000);
}
async function bootstrap() {
ensureUserId();
document.getElementById("upload-form").addEventListener("submit", (event) => {
handleUpload(event).catch((error) => setText("upload-status", error.message));
});
document.getElementById("session-form").addEventListener("submit", (event) => {
handleCreateSession(event).catch((error) => setText("session-status", error.message));
});
document.getElementById("submit-followup").onclick = () => {
handleFollowup().catch((error) => setText("detail-meta", error.message));
};
document.getElementById("close-session").onclick = () => {
handleCloseSession().catch((error) => setText("detail-meta", error.message));
};
document.getElementById("refresh-sessions").onclick = () => {
refreshSessions().catch((error) => console.error(error));
};
await refreshFiles();
await refreshSessions();
startPolling();
}
bootstrap().catch((error) => {
console.error(error);
setText("detail-meta", error.message);
});

88
webapp/static/index.html Normal file

@@ -0,0 +1,88 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Vibe Data Analysis</title>
<link rel="stylesheet" href="/static/style.css" />
</head>
<body>
<main class="app-shell">
<section class="hero">
<div>
<p class="eyebrow">Vibe Data Analysis</p>
<h1>在线数据分析会话</h1>
<p class="subtle">
上传文件,发起分析,会话结束后继续追问新的专题,不中断当前上下文。
</p>
</div>
<div class="identity-card">
<span>当前访客标识</span>
<code id="user-id"></code>
</div>
</section>
<section class="panel-grid">
<section class="panel">
<h2>1. 上传文件</h2>
<form id="upload-form" class="stack-form">
<input id="upload-input" type="file" multiple />
<button type="submit">上传并登记</button>
</form>
<div id="upload-status" class="hint"></div>
<div id="file-list" class="file-list"></div>
</section>
<section class="panel">
<h2>2. 新建分析会话</h2>
<form id="session-form" class="stack-form">
<input id="session-title" type="text" placeholder="会话标题,例如:工单健康度" required />
<textarea id="session-query" rows="5" placeholder="输入首个分析专题,例如:请先整体评估工单健康度,并指出最需要关注的问题。" required></textarea>
<div id="session-file-picker" class="checkbox-list"></div>
<button type="submit">创建会话并开始分析</button>
</form>
<div id="session-status" class="hint"></div>
</section>
</section>
<section class="layout">
<aside class="sidebar panel">
<div class="sidebar-header">
<h2>会话列表</h2>
<button id="refresh-sessions" type="button">刷新</button>
</div>
<div id="session-list" class="session-list"></div>
</aside>
<section class="content panel">
<div class="content-header">
<div>
<h2 id="detail-title">未选择会话</h2>
<p id="detail-meta" class="hint">选择左侧会话查看分析结果与后续专题。</p>
</div>
<button id="close-session" type="button" class="ghost">结束当前会话</button>
</div>
<div class="followup-box">
<textarea id="followup-query" rows="4" placeholder="如果还有新的专题想继续分析,在这里输入。"></textarea>
<button id="submit-followup" type="button">继续分析该专题</button>
</div>
<div class="tasks-area">
<div class="tasks-column">
<h3>专题任务</h3>
<div id="task-list" class="task-list"></div>
</div>
<div class="report-column">
<h3>报告展示</h3>
<div id="report-title" class="report-title">暂无报告</div>
<pre id="report-content" class="report-content">选择一个已完成任务查看报告。</pre>
<div id="artifact-gallery" class="artifact-gallery"></div>
</div>
</div>
</section>
</section>
</main>
<script src="/static/app.js"></script>
</body>
</html>

278
webapp/static/style.css Normal file

@@ -0,0 +1,278 @@
:root {
--bg: #f4efe7;
--panel: #fffaf2;
--line: #d8cdbd;
--text: #1f1a17;
--muted: #6e645a;
--accent: #b04a2f;
--accent-soft: #f2d5c4;
--success: #2f7d62;
--warning: #aa6a1f;
}
* {
box-sizing: border-box;
}
body {
margin: 0;
font-family: "IBM Plex Sans", "Noto Sans SC", sans-serif;
color: var(--text);
background:
radial-gradient(circle at top left, #fff7ec 0, transparent 28rem),
linear-gradient(180deg, #efe5d7 0%, var(--bg) 100%);
}
.app-shell {
max-width: 1440px;
margin: 0 auto;
padding: 24px;
}
.hero,
.panel,
.content,
.sidebar {
border: 1px solid var(--line);
background: rgba(255, 250, 242, 0.96);
backdrop-filter: blur(10px);
border-radius: 20px;
box-shadow: 0 12px 40px rgba(93, 67, 39, 0.08);
}
.hero {
display: flex;
justify-content: space-between;
gap: 24px;
padding: 24px 28px;
margin-bottom: 20px;
}
.eyebrow {
margin: 0 0 10px;
color: var(--accent);
text-transform: uppercase;
letter-spacing: 0.12em;
font-size: 12px;
}
h1,
h2,
h3,
p {
margin-top: 0;
}
.subtle,
.hint {
color: var(--muted);
}
.identity-card {
min-width: 220px;
padding: 16px;
border-radius: 16px;
background: linear-gradient(135deg, var(--accent-soft), #f8e7dc);
display: flex;
flex-direction: column;
gap: 8px;
}
.panel-grid,
.layout,
.tasks-area {
display: grid;
gap: 20px;
}
.panel-grid {
grid-template-columns: repeat(2, minmax(0, 1fr));
margin-bottom: 20px;
}
.layout {
grid-template-columns: 340px minmax(0, 1fr);
}
.tasks-area {
grid-template-columns: 320px minmax(0, 1fr);
}
.panel,
.content,
.sidebar {
padding: 20px;
}
.stack-form {
display: grid;
gap: 12px;
}
input,
textarea,
button {
font: inherit;
}
input,
textarea {
width: 100%;
padding: 12px 14px;
border-radius: 12px;
border: 1px solid var(--line);
background: #fffdf9;
}
button {
border: 0;
border-radius: 999px;
padding: 12px 18px;
background: var(--accent);
color: #fff;
cursor: pointer;
transition: transform 120ms ease, opacity 120ms ease;
}
button:hover {
transform: translateY(-1px);
opacity: 0.95;
}
button.ghost {
background: transparent;
color: var(--accent);
border: 1px solid var(--accent-soft);
}
.file-list,
.checkbox-list,
.session-list,
.task-list,
.artifact-gallery {
display: grid;
gap: 10px;
}
.file-item,
.session-card,
.task-card,
.artifact-card {
padding: 12px 14px;
border-radius: 14px;
border: 1px solid var(--line);
background: #fffdf8;
}
.checkbox-item {
display: flex;
align-items: center;
gap: 10px;
padding: 10px 12px;
border-radius: 12px;
border: 1px solid var(--line);
background: #fffdf8;
}
.sidebar-header,
.content-header {
display: flex;
justify-content: space-between;
align-items: start;
gap: 16px;
}
.followup-box {
margin: 20px 0;
display: grid;
gap: 12px;
}
.session-card.active,
.task-card.active {
border-color: var(--accent);
background: #fff3eb;
}
.status {
display: inline-flex;
padding: 4px 10px;
border-radius: 999px;
font-size: 12px;
font-weight: 600;
}
.status.queued {
background: #f1ead7;
color: #7b6114;
}
.status.running {
background: #e6edf9;
color: #1e5dab;
}
.status.succeeded,
.status.open {
background: #dff1ea;
color: var(--success);
}
.status.failed,
.status.closed {
background: #f8dfda;
color: #a03723;
}
.report-title {
margin-bottom: 10px;
color: var(--muted);
}
.report-content {
min-height: 360px;
max-height: 720px;
overflow: auto;
padding: 16px;
border-radius: 16px;
background: #1d1a18;
color: #f8f2ea;
line-height: 1.6;
white-space: pre-wrap;
}
.artifact-gallery {
margin-top: 16px;
grid-template-columns: repeat(auto-fill, minmax(220px, 1fr));
}
.artifact-card img {
width: 100%;
height: 180px;
object-fit: cover;
border-radius: 12px;
background: #ede3d7;
}
.artifact-card a {
color: var(--accent);
text-decoration: none;
}
.empty {
color: var(--muted);
font-style: italic;
}
@media (max-width: 1024px) {
.panel-grid,
.layout,
.tasks-area,
.hero {
grid-template-columns: 1fr;
}
.hero {
flex-direction: column;
}
}

311
webapp/storage.py Normal file

@@ -0,0 +1,311 @@
# -*- coding: utf-8 -*-
"""
SQLite-backed storage for uploaded files and analysis tasks.
"""
import json
import os
import sqlite3
import threading
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Any, Dict, Iterable, List, Optional
def utcnow_iso() -> str:
    # Timezone-aware equivalent of the deprecated datetime.utcnow().
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
class Storage:
"""Simple SQLite storage with thread-safe write operations."""
def __init__(self, db_path: str):
self.db_path = os.path.abspath(db_path)
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
self._write_lock = threading.Lock()
self.init_db()
@contextmanager
def _connect(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
conn.row_factory = sqlite3.Row
try:
yield conn
conn.commit()
finally:
conn.close()
def init_db(self) -> None:
with self._connect() as conn:
conn.executescript(
"""
CREATE TABLE IF NOT EXISTS uploaded_files (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
original_name TEXT NOT NULL,
stored_path TEXT NOT NULL,
created_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS analysis_sessions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
title TEXT NOT NULL,
status TEXT NOT NULL,
uploaded_file_ids TEXT NOT NULL,
template_file_id TEXT,
session_output_dir TEXT,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
closed_at TEXT
);
CREATE TABLE IF NOT EXISTS analysis_tasks (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
user_id TEXT NOT NULL,
query TEXT NOT NULL,
status TEXT NOT NULL,
uploaded_file_ids TEXT NOT NULL,
template_file_id TEXT,
session_output_dir TEXT,
report_file_path TEXT,
error_message TEXT,
created_at TEXT NOT NULL,
started_at TEXT,
finished_at TEXT
);
"""
)
def create_uploaded_file(
self, user_id: str, original_name: str, stored_path: str
) -> Dict[str, Any]:
record = {
"id": str(uuid.uuid4()),
"user_id": user_id,
"original_name": original_name,
"stored_path": os.path.abspath(stored_path),
"created_at": utcnow_iso(),
}
with self._write_lock, self._connect() as conn:
conn.execute(
"""
INSERT INTO uploaded_files (id, user_id, original_name, stored_path, created_at)
VALUES (:id, :user_id, :original_name, :stored_path, :created_at)
""",
record,
)
return record
def get_uploaded_file(self, file_id: str, user_id: str) -> Optional[Dict[str, Any]]:
with self._connect() as conn:
row = conn.execute(
"""
SELECT * FROM uploaded_files WHERE id = ? AND user_id = ?
""",
(file_id, user_id),
).fetchone()
return dict(row) if row else None
def list_uploaded_files(
self, file_ids: Iterable[str], user_id: str
) -> List[Dict[str, Any]]:
file_ids = list(file_ids)
if not file_ids:
return []
placeholders = ",".join("?" for _ in file_ids)
params = [*file_ids, user_id]
with self._connect() as conn:
rows = conn.execute(
f"""
SELECT * FROM uploaded_files
WHERE id IN ({placeholders}) AND user_id = ?
ORDER BY created_at ASC
""",
params,
).fetchall()
return [dict(row) for row in rows]
def list_all_uploaded_files(self, user_id: str) -> List[Dict[str, Any]]:
with self._connect() as conn:
rows = conn.execute(
"""
SELECT * FROM uploaded_files
WHERE user_id = ?
ORDER BY created_at DESC
""",
(user_id,),
).fetchall()
return [dict(row) for row in rows]
def create_task(
self,
session_id: str,
user_id: str,
query: str,
uploaded_file_ids: List[str],
template_file_id: Optional[str] = None,
) -> Dict[str, Any]:
record = {
"id": str(uuid.uuid4()),
"session_id": session_id,
"user_id": user_id,
"query": query,
"status": "queued",
"uploaded_file_ids": json.dumps(uploaded_file_ids, ensure_ascii=False),
"template_file_id": template_file_id,
"session_output_dir": None,
"report_file_path": None,
"error_message": None,
"created_at": utcnow_iso(),
"started_at": None,
"finished_at": None,
}
with self._write_lock, self._connect() as conn:
conn.execute(
"""
INSERT INTO analysis_tasks (
id, session_id, user_id, query, status, uploaded_file_ids, template_file_id,
session_output_dir, report_file_path, error_message,
created_at, started_at, finished_at
)
VALUES (
:id, :session_id, :user_id, :query, :status, :uploaded_file_ids, :template_file_id,
:session_output_dir, :report_file_path, :error_message,
:created_at, :started_at, :finished_at
)
""",
record,
)
return self.get_task(record["id"], user_id)
def get_task(self, task_id: str, user_id: str) -> Optional[Dict[str, Any]]:
with self._connect() as conn:
row = conn.execute(
"""
SELECT * FROM analysis_tasks WHERE id = ? AND user_id = ?
""",
(task_id, user_id),
).fetchone()
return self._normalize_task(row) if row else None
def list_tasks(self, user_id: str) -> List[Dict[str, Any]]:
with self._connect() as conn:
rows = conn.execute(
"""
SELECT * FROM analysis_tasks
WHERE user_id = ?
ORDER BY created_at DESC
""",
(user_id,),
).fetchall()
return [self._normalize_task(row) for row in rows]
def list_session_tasks(self, session_id: str, user_id: str) -> List[Dict[str, Any]]:
with self._connect() as conn:
rows = conn.execute(
"""
SELECT * FROM analysis_tasks
WHERE session_id = ? AND user_id = ?
ORDER BY created_at ASC
""",
(session_id, user_id),
).fetchall()
return [self._normalize_task(row) for row in rows]
def create_session(
self,
user_id: str,
title: str,
uploaded_file_ids: List[str],
template_file_id: Optional[str] = None,
) -> Dict[str, Any]:
now = utcnow_iso()
record = {
"id": str(uuid.uuid4()),
"user_id": user_id,
"title": title,
"status": "open",
"uploaded_file_ids": json.dumps(uploaded_file_ids, ensure_ascii=False),
"template_file_id": template_file_id,
"session_output_dir": None,
"created_at": now,
"updated_at": now,
"closed_at": None,
}
with self._write_lock, self._connect() as conn:
conn.execute(
"""
INSERT INTO analysis_sessions (
id, user_id, title, status, uploaded_file_ids, template_file_id,
session_output_dir, created_at, updated_at, closed_at
)
VALUES (
:id, :user_id, :title, :status, :uploaded_file_ids, :template_file_id,
:session_output_dir, :created_at, :updated_at, :closed_at
)
""",
record,
)
return self.get_session(record["id"], user_id)
def get_session(self, session_id: str, user_id: str) -> Optional[Dict[str, Any]]:
with self._connect() as conn:
row = conn.execute(
"""
SELECT * FROM analysis_sessions WHERE id = ? AND user_id = ?
""",
(session_id, user_id),
).fetchone()
return self._normalize_session(row) if row else None
def list_sessions(self, user_id: str) -> List[Dict[str, Any]]:
with self._connect() as conn:
rows = conn.execute(
"""
SELECT * FROM analysis_sessions
WHERE user_id = ?
ORDER BY updated_at DESC
""",
(user_id,),
).fetchall()
return [self._normalize_session(row) for row in rows]
def update_session(self, session_id: str, **fields: Any) -> None:
if not fields:
return
fields["updated_at"] = utcnow_iso()
assignments = ", ".join(f"{key} = :{key}" for key in fields.keys())
payload = dict(fields)
payload["id"] = session_id
with self._write_lock, self._connect() as conn:
conn.execute(
f"UPDATE analysis_sessions SET {assignments} WHERE id = :id",
payload,
)
def update_task(self, task_id: str, **fields: Any) -> None:
if not fields:
return
assignments = ", ".join(f"{key} = :{key}" for key in fields.keys())
payload = dict(fields)
payload["id"] = task_id
with self._write_lock, self._connect() as conn:
conn.execute(
f"UPDATE analysis_tasks SET {assignments} WHERE id = :id",
payload,
)
@staticmethod
def _normalize_task(row: sqlite3.Row) -> Dict[str, Any]:
task = dict(row)
task["uploaded_file_ids"] = json.loads(task["uploaded_file_ids"])
return task
@staticmethod
def _normalize_session(row: sqlite3.Row) -> Dict[str, Any]:
session = dict(row)
session["uploaded_file_ids"] = json.loads(session["uploaded_file_ids"])
return session
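
`list_uploaded_files` builds its `IN (...)` clause by joining one `?` placeholder per id and then appending `user_id` to the parameter list. A minimal self-contained sketch of that pattern, using an in-memory database with made-up table and ids for illustration:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE files (id TEXT, user_id TEXT)")
conn.executemany(
    "INSERT INTO files VALUES (?, ?)",
    [("a", "u1"), ("b", "u1"), ("c", "u2")],
)

file_ids = ["a", "b", "c"]
# One "?" per id keeps every value parameterized even though the
# clause itself is assembled as a string.
placeholders = ",".join("?" for _ in file_ids)
rows = conn.execute(
    f"SELECT id FROM files WHERE id IN ({placeholders}) AND user_id = ? ORDER BY id",
    [*file_ids, "u1"],
).fetchall()
print([r[0] for r in rows])  # → ['a', 'b']
```

Only the clause shape is built dynamically; the ids and the user id still travel as bound parameters, which is why this stays safe against injection.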

147
webapp/task_runner.py Normal file

@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
"""
Background task runner for analysis jobs.
"""
import os
import shutil
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import redirect_stderr, redirect_stdout
from typing import Optional
from utils.create_session_dir import create_session_output_dir
from webapp.session_manager import SessionManager
from webapp.storage import Storage, utcnow_iso
class TaskRunner:
"""Runs analysis tasks in background worker threads."""
def __init__(
self,
storage: Storage,
uploads_dir: str,
outputs_dir: str,
session_manager: SessionManager,
max_workers: int = 2,
):
self.storage = storage
self.uploads_dir = os.path.abspath(uploads_dir)
self.outputs_dir = os.path.abspath(outputs_dir)
self.session_manager = session_manager
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._lock = threading.Lock()
self._submitted = set()
def submit(self, task_id: str, user_id: str) -> None:
with self._lock:
if task_id in self._submitted:
return
self._submitted.add(task_id)
self._executor.submit(self._run_task, task_id, user_id)
def _run_task(self, task_id: str, user_id: str) -> None:
try:
task = self.storage.get_task(task_id, user_id)
if not task:
return
session = self.storage.get_session(task["session_id"], user_id)
if not session:
return
uploaded_files = self.storage.list_uploaded_files(
task["uploaded_file_ids"], user_id
)
data_files = [item["stored_path"] for item in uploaded_files]
template_path = self._resolve_template_path(task, user_id)
session_output_dir = session.get("session_output_dir")
if not session_output_dir:
session_output_dir = create_session_output_dir(
self.outputs_dir, session["title"]
)
self.storage.update_session(
session["id"],
session_output_dir=session_output_dir,
)
session = self.storage.get_session(task["session_id"], user_id)
runtime = self.session_manager.get_or_create(
session_id=session["id"],
user_id=user_id,
session_output_dir=session_output_dir,
uploaded_files=data_files,
template_path=template_path,
)
self.storage.update_task(
task_id,
status="running",
session_output_dir=session_output_dir,
started_at=utcnow_iso(),
error_message=None,
)
self.storage.update_session(session["id"], status="running")
log_path = os.path.join(session_output_dir, "task.log")
with runtime.lock:
with open(log_path, "a", encoding="utf-8") as log_file:
log_file.write(
f"[{utcnow_iso()}] task started for session {session['id']}\n"
)
try:
with redirect_stdout(log_file), redirect_stderr(log_file):
result = runtime.agent.analyze(
user_input=task["query"],
files=data_files,
template_path=template_path,
session_output_dir=session_output_dir,
reset_context=not runtime.initialized,
keep_session_open=True,
)
runtime.initialized = True
except Exception as exc:
self.storage.update_task(
task_id,
status="failed",
error_message=str(exc),
finished_at=utcnow_iso(),
report_file_path=None,
)
self.storage.update_session(session["id"], status="open")
log_file.write(f"[{utcnow_iso()}] task failed: {exc}\n")
return
report_file_path = self._persist_task_report(
task_id, session_output_dir, result.get("report_file_path")
)
self.storage.update_task(
task_id,
status="succeeded",
report_file_path=report_file_path,
finished_at=utcnow_iso(),
error_message=None,
)
self.storage.update_session(session["id"], status="open")
        except Exception as exc:
            # Without this, a failure outside the inner try (e.g. while
            # preparing the session runtime) left the task stuck in
            # "queued"/"running" with no recorded error.
            self.storage.update_task(
                task_id,
                status="failed",
                error_message=str(exc),
                finished_at=utcnow_iso(),
            )
        finally:
            with self._lock:
                self._submitted.discard(task_id)
def _resolve_template_path(self, task: dict, user_id: str) -> Optional[str]:
template_file_id = task.get("template_file_id")
if not template_file_id:
return None
file_record = self.storage.get_uploaded_file(template_file_id, user_id)
return file_record["stored_path"] if file_record else None
@staticmethod
def _persist_task_report(
task_id: str, session_output_dir: str, current_report_path: Optional[str]
) -> Optional[str]:
if not current_report_path or not os.path.exists(current_report_path):
return current_report_path
task_report_path = os.path.join(session_output_dir, f"report_{task_id}.md")
if os.path.abspath(current_report_path) != os.path.abspath(task_report_path):
shutil.copyfile(current_report_path, task_report_path)
return task_report_path
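
`_run_task` captures everything the agent prints on either stream into `task.log` via `redirect_stdout`/`redirect_stderr`. The capture mechanism in isolation, writing to a `StringIO` instead of a file (purely illustrative):

```python
import io
import sys
from contextlib import redirect_stderr, redirect_stdout

log = io.StringIO()
with redirect_stdout(log), redirect_stderr(log):
    # Anything the wrapped code emits on stdout or stderr lands in `log`,
    # just as the agent's console output lands in task.log.
    print("step 1 done")
    sys.stderr.write("warning: slow query\n")
print(log.getvalue())  # → "step 1 done\nwarning: slow query\n"
```

Note that `redirect_stdout` only swaps `sys.stdout` at the Python level; output written by subprocesses or C extensions directly to the OS file descriptors is not captured this way.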