Add web session analysis platform with follow-up topics
9	.gitignore	vendored	Normal file
@@ -0,0 +1,9 @@
__pycache__/
*.pyc
.DS_Store
.env
.env copy
outputs/
runtime/
*.log
log.txt
21	LICENSE	Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Data Analysis Agent Team

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
357	README.md	Normal file
@@ -0,0 +1,357 @@
# Data Analysis Agent

🤖 **An LLM-powered intelligent data analysis agent**

[](https://python.org)
[](LICENSE)
[](https://openai.com)

## 📋 Project Overview

The Data Analysis Agent is a powerful Python tool that combines the language-understanding abilities of large language models (LLMs) with the computational power of Python's data analysis libraries. It can:

- 🎯 **Natural-language analysis**: accept requirements expressed in natural language and automatically generate professional data analysis code
- 📊 **Intelligent visualization**: automatically produce high-quality charts, with Chinese font support, saved to a dedicated output directory
- 🔄 **Iterative refinement**: adjust the analysis strategy based on execution results, continuously improving analysis quality
- 📝 **Report generation**: automatically produce professional reports containing charts and conclusions (Markdown + Word)
- 🛡️ **Safe execution**: execute code in a restricted environment that supports the common data analysis libraries

## 🏗️ Project Architecture

```
data_analysis_agent/
├── 📁 config/                       # Configuration management
│   ├── __init__.py
│   └── llm_config.py                # LLM configuration (API key, model, etc.)
├── 📁 utils/                        # Core utility modules
│   ├── code_executor.py             # Safe code executor
│   ├── llm_helper.py                # LLM call helper
│   ├── fallback_openai_client.py    # OpenAI client with failover support
│   ├── extract_code.py              # Code extraction utility
│   ├── format_execution_result.py   # Execution result formatting
│   └── create_session_dir.py        # Session directory management
├── 📄 data_analysis_agent.py        # Main agent class
├── 📄 prompts.py                    # System prompt templates
├── 📄 main.py                       # Usage example
├── 📄 requirements.txt              # Project dependencies
├── 📄 .env                          # Environment variable configuration
└── 📁 outputs/                      # Analysis output directory
    └── session_[timestamp]/         # One session directory per analysis run
        ├── *.png                    # Generated charts
        ├── 最终分析报告.md           # Markdown report
        └── 最终分析报告.docx         # Word report
```
## 📊 Data Analysis Flowchart

The complete analysis flow, shown as a Mermaid diagram:

```mermaid
graph TD
    A[User states a request in natural language] --> B[Initialize the agent]
    B --> C[Create a dedicated session directory]
    C --> D[LLM interprets the request and generates code]
    D --> E[Safe code executor runs the code]
    E --> F{Execution successful?}
    F -->|No| G[Analyze and fix the error]
    G --> D
    F -->|Yes| H[Format and store the results]
    H --> I{More analysis needed?}
    I -->|Yes| J[Continue analysis from current results]
    J --> D
    I -->|No| K[Collect all charts]
    K --> L[Generate the final analysis report]
    L --> M[Output Markdown and Word reports]
    M --> N[Analysis complete]

    style A fill:#e1f5fe
    style N fill:#c8e6c9
    style F fill:#fff3e0
    style I fill:#fff3e0
```
## 🔄 Agent Workflow

```mermaid
sequenceDiagram
    participant User as User
    participant Agent as Data Analysis Agent
    participant LLM as Language Model
    participant Executor as Code Executor
    participant Storage as File Storage

    User->>Agent: Provide data files and analysis request
    Agent->>Storage: Create dedicated session directory

    loop Multi-round analysis loop
        Agent->>LLM: Send analysis request and context
        LLM->>Agent: Return analysis code and reasoning
        Agent->>Executor: Execute Python code
        Executor->>Storage: Save chart files
        Executor->>Agent: Return execution results

        alt More analysis needed
            Agent->>LLM: Continue analysis from results
        else Analysis complete
            Agent->>LLM: Generate final report
            LLM->>Agent: Return analysis report
            Agent->>Storage: Save report files
        end
    end

    Agent->>User: Return complete analysis results
```
## ✨ Core Features

### 🧠 Intelligent Analysis Flow

- **Multi-stage analysis**: data exploration → cleaning checks → analysis and visualization → chart collection → report generation
- **Error self-healing**: automatically detects and fixes common errors (encoding, column names, data types, etc.)
- **Context persistence**: variables and state persist in the notebook environment throughout the analysis

### 📋 Multi-format Reports

- **Markdown report**: a structured analysis report with chart references
- **Word document**: a professional document format, convenient for sharing and printing
- **Chart integration**: generated charts are referenced automatically in the report
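The notebook-style context persistence described above can be approximated with a single namespace dict shared across `exec` calls. A minimal sketch of the idea only; the project's actual executor (`ExecutionSessionClient`) runs code in a separate worker process:

```python
class TinyNotebookSession:
    """Keep variables alive across code cells, notebook-style."""

    def __init__(self):
        self.namespace = {}  # shared across all executed cells

    def run(self, code: str):
        # Every cell reads and mutates the same namespace,
        # so later cells see variables defined by earlier ones.
        exec(code, self.namespace)


session = TinyNotebookSession()
session.run("x = 40")
session.run("y = x + 2")  # `x` survives from the previous cell
```

This is why a multi-round analysis can build on earlier results instead of reloading the data each round.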
## 🚀 Quick Start

### 1. Environment Setup

```bash
# Clone the project
git clone https://github.com/li-xiu-qi/data_analysis_agent.git
cd data_analysis_agent

# Install dependencies
pip install -r requirements.txt
```

### 2. Configure the API Key

Create a `.env` file:

```bash
# OpenAI API configuration
OPENAI_API_KEY=your_api_key_here
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_MODEL=gpt-4

# Or use a compatible API (e.g. Volcengine)
# OPENAI_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
# OPENAI_MODEL=deepseek-v3-250324
```
### 3. Basic Usage

```python
from data_analysis_agent import DataAnalysisAgent
from config.llm_config import LLMConfig

# Initialize the agent
llm_config = LLMConfig()
agent = DataAnalysisAgent(llm_config)

# Start the analysis
files = ["your_data.csv"]
report = agent.analyze(
    user_input="分析销售数据,生成趋势图表和关键指标",
    files=files
)

print(report)
```

```python
# Custom configuration
agent = DataAnalysisAgent(
    llm_config=llm_config,
    output_dir="custom_outputs",  # custom output directory
    max_rounds=30                 # raise the maximum number of analysis rounds
)

# Use the convenience function
from data_analysis_agent import quick_analysis

report = quick_analysis(
    query="分析用户行为数据,重点关注转化率",
    files=["user_behavior.csv"],
    max_rounds=15
)
```
## 📊 Usage Example

A complete example analyzing Kweichow Moutai's financial data:

```python
# Example: Moutai financial analysis
files = ["贵州茅台利润表.csv"]
report = agent.analyze(
    user_input="基于贵州茅台的数据,输出五个重要的统计指标,并绘制相关图表。最后生成汇报给我。",
    files=files
)
```

**The generated analysis includes:**

- 📈 Total operating revenue trend chart
- 💰 Net profit margin analysis
- 📊 Profit composition charts
- 💵 Earnings-per-share trend
- 📋 Operating cost ratio analysis
- 📄 A comprehensive analysis report
## 🎨 Flow Visualization

### 📊 Analysis State Diagram

```mermaid
stateDiagram-v2
    [*] --> DataLoading
    DataLoading --> DataExploration: loaded successfully
    DataLoading --> EncodingRepair: encoding error
    EncodingRepair --> DataExploration: repair complete

    DataExploration --> DataCleaning: exploration complete
    DataCleaning --> StatisticalAnalysis: cleaning complete
    StatisticalAnalysis --> VisualizationGeneration: analysis complete

    VisualizationGeneration --> ChartSaving: charts generated
    ChartSaving --> ResultEvaluation: saving complete

    ResultEvaluation --> ContinueAnalysis: more analysis needed
    ResultEvaluation --> ReportGeneration: analysis sufficient
    ContinueAnalysis --> StatisticalAnalysis

    ReportGeneration --> [*]: done
```
## 🔧 Configuration Options

### LLM Configuration

```python
@dataclass
class LLMConfig:
    provider: str = "openai"
    api_key: str = os.environ.get("OPENAI_API_KEY", "")
    base_url: str = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
    model: str = os.environ.get("OPENAI_MODEL", "gpt-4")
    max_tokens: int = 4000
    temperature: float = 0.1
```

### Executor Configuration

```python
# Allowed libraries
ALLOWED_IMPORTS = {
    'pandas', 'numpy', 'matplotlib', 'duckdb',
    'scipy', 'sklearn', 'plotly', 'requests',
    'os', 'json', 'datetime', 're', 'pathlib'
}
```
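One way a whitelist like `ALLOWED_IMPORTS` can be enforced is to parse the generated code with `ast` and reject it before execution if it imports anything off the list. This is a sketch of the technique, not the project's actual `code_executor` implementation:

```python
import ast
from typing import List

ALLOWED_IMPORTS = {
    'pandas', 'numpy', 'matplotlib', 'duckdb',
    'scipy', 'sklearn', 'plotly', 'requests',
    'os', 'json', 'datetime', 're', 'pathlib'
}


def find_disallowed_imports(code: str) -> List[str]:
    """Return the top-level module names imported by `code`
    that are not on the whitelist."""
    tree = ast.parse(code)
    disallowed = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            # `from . import x` has module=None; skip relative imports here
            names = [node.module] if node.module else []
        else:
            continue
        for name in names:
            top_level = name.split(".")[0]  # e.g. sklearn.linear_model -> sklearn
            if top_level not in ALLOWED_IMPORTS:
                disallowed.append(top_level)
    return disallowed
```

A static check like this catches `import subprocess` before it runs, though it is not a sandbox on its own (e.g. `__import__` calls would need separate handling).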
## 🎯 Best Practices

### 1. Data Preparation

- ✅ Use CSV format; UTF-8 and GBK encodings are supported
- ✅ Make sure column names are clear and free of special characters
- ✅ Keep the data size moderate (under 100 MB recommended)

### 2. Writing Queries

- ✅ Describe the analysis requirement in clear language (Chinese prompts are supported)
- ✅ Specify the chart types and key metrics you want
- ✅ Make the goal and focus of the analysis explicit

### 3. Interpreting Results

- ✅ Check whether the generated charts match expectations
- ✅ Read the key findings in the analysis report
- ✅ Adjust the query and rerun the analysis as needed
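The UTF-8/GBK support mentioned above typically comes down to trying encodings in order until one succeeds. A sketch of that pattern; `read_csv_any_encoding` is a hypothetical helper for illustration, not a function exported by this project:

```python
import pandas as pd


def read_csv_any_encoding(path: str) -> pd.DataFrame:
    """Try common encodings in order until one succeeds."""
    for encoding in ("utf-8", "utf-8-sig", "gbk"):
        try:
            return pd.read_csv(path, encoding=encoding)
        except UnicodeDecodeError:
            continue
    # Last resort: a lossy read rather than failing outright
    # (encoding_errors requires pandas >= 1.3)
    return pd.read_csv(path, encoding="utf-8", encoding_errors="replace")
```

This mirrors the kind of "encoding repair" step shown in the state diagram: a failed load falls through to the next candidate instead of aborting the analysis.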
## 🚨 Notes

### Security Restrictions

- 🔒 Only the predefined data analysis libraries are supported
- 🔒 File system operations are not allowed (other than saving charts)
- 🔒 Network requests are not supported (other than LLM calls)

### Performance Considerations

- ⚡ Large datasets can make analysis take considerably longer
- ⚡ Complex analysis tasks may require multiple interaction rounds
- ⚡ API call frequency is subject to the model provider's limits

### Compatibility

- 🐍 Python 3.8+
- 📊 Supports any pandas-compatible data format
- 🖼️ Requires matplotlib Chinese font support
## 🐛 Troubleshooting

### Common Issues

**Q: Chinese text in charts renders as boxes?**
A: The system automatically detects and uses an available Chinese font (macOS: Hiragino Sans GB, Songti SC, etc.; Windows: SimHei, etc.).
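Font detection of this kind usually means scanning matplotlib's registered fonts for the first known Chinese-capable candidate. A sketch under that assumption; `configure_chinese_font` and the candidate list are illustrative, not the project's actual code:

```python
from typing import Optional

from matplotlib import font_manager, rcParams

# Candidate fonts in priority order (macOS, Windows, Linux)
CANDIDATES = ["Hiragino Sans GB", "Songti SC", "SimHei",
              "Microsoft YaHei", "Noto Sans CJK SC"]


def configure_chinese_font() -> Optional[str]:
    """Set the first installed candidate as matplotlib's default font.

    Returns the chosen font name, or None if no candidate is installed."""
    installed = {f.name for f in font_manager.fontManager.ttflist}
    for name in CANDIDATES:
        if name in installed:
            rcParams["font.sans-serif"] = [name]
            rcParams["axes.unicode_minus"] = False  # render minus signs correctly
            return name
    return None
```

If this returns `None`, installing a CJK font such as Noto Sans CJK SC and clearing matplotlib's font cache is the usual fix.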
**Q: API calls fail?**
A: Check the API key and endpoint configuration in the `.env` file, and make sure the network connection is working.

**Q: Data fails to load?**
A: Check the file path and encoding; common encodings such as UTF-8 and GBK are supported.

**Q: Analysis results are inaccurate?**
A: Try providing a more detailed analysis requirement, or check the quality of the raw data.

**Q: Mermaid diagrams don't render?**
A: View them in an environment that supports Mermaid (e.g. GitHub, Typora, the VS Code preview). For local viewing, use a Markdown editor with Mermaid support.

**Q: How do I customize the diagram styles?**
A: Add style definitions inside the Mermaid code block, or use a different diagram type (graph, flowchart, sequenceDiagram, etc.) to fit your presentation needs.

### Error Logs

Error messages produced during analysis are saved in the session directory, which makes debugging and tuning easier.
## 🤝 Contributing

Code contributions and suggestions for improvement are welcome!

1. Fork the project
2. Create a feature branch
3. Commit your changes
4. Push to the branch
5. Open a Pull Request

## 📄 License

This project is open-sourced under the MIT License. See the [LICENSE](LICENSE) file for details.

## 🔄 Changelog

### v1.0.0

- ✨ Initial release
- 🎯 Natural-language data analysis
- 📊 Integrated matplotlib chart generation
- 📝 Automatic report generation
- 🔒 Safe code execution environment

---

<div align="center">

**🚀 Make data analysis smarter and simpler!**

</div>
5901	UB IOV Support_TR.csv	Executable file
File diff suppressed because it is too large
54	__init__.py	Normal file
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
"""
Data Analysis Agent Package

An LLM-powered intelligent data analysis agent, designed for Jupyter Notebook environments.
"""

from .core.notebook_agent import NotebookAgent
from .config.llm_config import LLMConfig
from .utils.code_executor import CodeExecutor

__version__ = "1.0.0"
__author__ = "Data Analysis Agent Team"

# Main exported classes
__all__ = [
    "NotebookAgent",
    "LLMConfig",
    "CodeExecutor",
]


# Convenience functions
def create_agent(config=None, output_dir="outputs", max_rounds=20, session_dir=None):
    """
    Create a data analysis agent instance.

    Args:
        config: LLM configuration; the default configuration is used when None
        output_dir: output directory
        max_rounds: maximum number of analysis rounds
        session_dir: session directory to use (optional)

    Returns:
        NotebookAgent: the agent instance
    """
    if config is None:
        config = LLMConfig()
    return NotebookAgent(config=config, output_dir=output_dir, max_rounds=max_rounds, session_dir=session_dir)


def quick_analysis(query, files=None, output_dir="outputs", max_rounds=10):
    """
    Quick data analysis helper.

    Args:
        query: analysis requirement (natural language)
        files: list of data file paths
        output_dir: output directory
        max_rounds: maximum number of analysis rounds

    Returns:
        dict: the analysis result
    """
    agent = create_agent(output_dir=output_dir, max_rounds=max_rounds)
    return agent.analyze(query, files)
8	config/__init__.py	Normal file
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
"""
Configuration module.
"""

from .llm_config import LLMConfig

__all__ = ['LLMConfig']
44	config/llm_config.py	Normal file
@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
"""
Configuration management module.
"""

import os
from typing import Dict, Any
from dataclasses import dataclass, asdict

from dotenv import load_dotenv

load_dotenv()


@dataclass
class LLMConfig:
    """LLM configuration."""

    provider: str = "openai"  # openai, anthropic, etc.
    api_key: str = os.environ.get("OPENAI_API_KEY", "")
    base_url: str = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
    model: str = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
    temperature: float = 0.3
    max_tokens: int = 131072

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LLMConfig":
        """Create a configuration from a dictionary."""
        return cls(**data)

    def validate(self) -> bool:
        """Validate the configuration."""
        if not self.api_key:
            raise ValueError("OPENAI_API_KEY is required")
        if not self.base_url:
            raise ValueError("OPENAI_BASE_URL is required")
        if not self.model:
            raise ValueError("OPENAI_MODEL is required")
        return True
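Because `LLMConfig` is a plain dataclass, `to_dict`/`from_dict` round-trip cleanly, which is handy for persisting or tweaking a configuration. A standalone sketch using a simplified copy of the dataclass (the dotenv loading and `validate` are omitted so the snippet has no third-party dependency):

```python
import os
from dataclasses import dataclass, asdict
from typing import Any, Dict


@dataclass
class LLMConfig:
    provider: str = "openai"
    api_key: str = os.environ.get("OPENAI_API_KEY", "")
    base_url: str = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
    model: str = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
    temperature: float = 0.3
    max_tokens: int = 131072

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LLMConfig":
        return cls(**data)


# Round-trip: serialize, tweak one field, and rebuild
config = LLMConfig(api_key="sk-test", model="gpt-4")
data = config.to_dict()
data["temperature"] = 0.0
restored = LLMConfig.from_dict(data)
```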
628	data_analysis_agent.py	Normal file
@@ -0,0 +1,628 @@
# -*- coding: utf-8 -*-
"""
Simplified Notebook data analysis agent.

Contains only two roles: user and assistant.
2. Charts must be saved into the designated session directory with their absolute paths printed; plt.show() is forbidden
3. Table output control: tables longer than 15 rows show only the first 5 and last 5 rows
4. The SimHei font is enforced: plt.rcParams['font.sans-serif'] = ['SimHei']
5. The output format strictly uses a single-turn dialogue mode with a shared YAML context
"""

import os
import json
import yaml
from typing import Dict, Any, List, Optional

from utils.create_session_dir import create_session_output_dir
from utils.format_execution_result import format_execution_result
from utils.extract_code import extract_code_from_response
from utils.data_loader import load_and_profile_data
from utils.llm_helper import LLMHelper, LLMCallError
from utils.execution_session_client import (
    ExecutionSessionClient,
    WorkerSessionError,
    WorkerTimeoutError,
)
from config.llm_config import LLMConfig
from prompts import data_analysis_system_prompt, final_report_system_prompt
class DataAnalysisAgent:
    """
    Data analysis agent.

    Responsibilities:
    - Receive the user's natural-language requirement
    - Generate Python analysis code
    - Execute the code and collect results
    - Generate follow-up analysis code based on execution results
    """

    def __init__(
        self,
        llm_config: LLMConfig = None,
        output_dir: str = "outputs",
        max_rounds: int = 20,
        force_max_rounds: bool = False,
    ):
        """
        Initialize the agent.

        Args:
            llm_config: LLM configuration
            output_dir: output directory
            max_rounds: maximum number of dialogue rounds
            force_max_rounds: run all the way to max_rounds (ignore the AI's completion signal)
        """
        self.config = llm_config or LLMConfig()
        self.llm = LLMHelper(self.config)
        self.base_output_dir = output_dir
        self.max_rounds = max_rounds
        self.force_max_rounds = force_max_rounds
        # Dialogue history and context
        self.conversation_history = []
        self.analysis_results = []
        self.current_round = 0
        self.session_output_dir = None
        self.executor = None
        self.data_profile = ""  # Stores the data profile
        self.fatal_error = ""
        self.fatal_error_stage = ""
        self.session_files = []
        self.template_content = "未提供特定模板,请根据数据画像自主发挥。"
    def _process_response(self, response: str) -> Dict[str, Any]:
        """
        Process an LLM response uniformly: determine the action type and act on it.

        Args:
            response: the LLM response content

        Returns:
            A result dictionary
        """
        try:
            yaml_data = self.llm.parse_yaml_response(response)
            action = yaml_data.get("action", "generate_code")

            print(f"🎯 Detected action: {action}")

            if action == "analysis_complete":
                return self._handle_analysis_complete(response, yaml_data)
            elif action == "collect_figures":
                return self._handle_collect_figures(response, yaml_data)
            elif action == "generate_code":
                return self._handle_generate_code(response, yaml_data)
            else:
                print(f"⚠️ Unknown action type: {action}; treating as generate_code")
                return self._handle_generate_code(response, yaml_data)

        except Exception as e:
            print(f"⚠️ Failed to parse response: {str(e)}; treating as generate_code")
            return self._handle_generate_code(response, {})
    def _handle_analysis_complete(
        self, response: str, yaml_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Handle the analysis-complete action."""
        print("✅ Analysis task complete")
        final_report = yaml_data.get("final_report", "分析完成,无最终报告")
        return {
            "action": "analysis_complete",
            "final_report": final_report,
            "response": response,
            "continue": False,
        }

    def _handle_collect_figures(
        self, response: str, yaml_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Handle the figure-collection action."""
        print("📊 Collecting figures")
        figures_to_collect = yaml_data.get("figures_to_collect", [])

        collected_figures = []

        for figure_info in figures_to_collect:
            figure_number = figure_info.get("figure_number", "未知")
            # Only use figure_number in the filename when it is known
            if figure_number != "未知":
                default_filename = f"figure_{figure_number}.png"
            else:
                default_filename = "figure_unknown.png"
            filename = figure_info.get("filename", default_filename)
            file_path = figure_info.get("file_path", "")  # The concrete file path
            description = figure_info.get("description", "")
            analysis = figure_info.get("analysis", "")

            print(f"📈 Collecting figure {figure_number}: {filename}")
            print(f"   📂 Path: {file_path}")
            print(f"   📝 Description: {description}")
            print(f"   🔍 Analysis: {analysis}")

            # Verify the file exists.
            # Only files that actually exist are added, to avoid broken images in the report.
            if file_path and os.path.exists(file_path):
                print(f"   ✅ File exists: {file_path}")
                # Record the figure info
                collected_figures.append(
                    {
                        "figure_number": figure_number,
                        "filename": filename,
                        "file_path": file_path,
                        "description": description,
                        "analysis": analysis,
                    }
                )
            else:
                if file_path:
                    print(f"   ⚠️ File does not exist: {file_path}")
                else:
                    print("   ⚠️ No file path provided")

        return {
            "action": "collect_figures",
            "collected_figures": collected_figures,
            "response": response,
            "continue": True,
        }
    def _handle_generate_code(
        self, response: str, yaml_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Handle the code generation and execution action."""
        # Prefer the code from the YAML data (more accurate)
        code = yaml_data.get("code", "")

        # If the YAML contains no code, try extracting it from the raw response
        if not code:
            code = extract_code_from_response(response)

        # Second pass: strip markdown fences that may survive YAML parsing
        if code:
            code = code.strip()
            if code.startswith("```"):
                import re
                # Remove a leading ```python or ```
                code = re.sub(r"^```[a-zA-Z]*\n", "", code)
                # Remove a trailing ```
                code = re.sub(r"\n```$", "", code)
                code = code.strip()

        if code:
            # Keep this marker in Chinese: main.py's DualLogger filters log lines on it
            print(f"🔧 执行代码:\n{code}")
            print("-" * 40)

            # Execute the code
            result = self.executor.execute_code(code)

            # Format the execution result
            feedback = format_execution_result(result)
            print(f"📋 Execution feedback:\n{feedback}")

            return {
                "action": "generate_code",
                "code": code,
                "result": result,
                "feedback": feedback,
                "response": response,
                "continue": True,
            }
        else:
            # No code found: the LLM response format is wrong, ask it to regenerate
            print("⚠️ No executable code found in the response; asking the LLM to regenerate")
            return {
                "action": "invalid_response",
                "error": "响应中缺少可执行代码",
                "response": response,
                "continue": True,
            }
    def analyze(
        self,
        user_input: str,
        files: List[str] = None,
        template_path: str = None,
        session_output_dir: str = None,
        reset_context: bool = True,
        keep_session_open: bool = False,
    ) -> Dict[str, Any]:
        """
        Run the analysis flow.

        Args:
            user_input: the user's natural-language requirement
            files: list of data file paths
            template_path: path to a reference template (optional)
            session_output_dir: session output directory to use (optional)
            reset_context: start a fresh session context (default True)
            keep_session_open: keep the execution session open after the run

        Returns:
            The analysis result dictionary
        """
        files = files or []
        self.current_round = 0
        self.fatal_error = ""
        self.fatal_error_stage = ""
        if reset_context or not self.executor:
            self.conversation_history = []
            self.analysis_results = []
            self._initialize_session(user_input, files, template_path, session_output_dir)
            initial_prompt = self._build_initial_user_prompt(user_input, files)
        else:
            self._validate_followup_files(files)
            if template_path:
                self.template_content = self._load_template_content(template_path)
            initial_prompt = self._build_followup_user_prompt(user_input)

        print("🚀 Starting data analysis task")
        print(f"📝 User request: {user_input}")
        if files:
            print(f"📁 Data files: {', '.join(files)}")
        if template_path:
            print(f"📄 Reference template: {template_path}")
        print(f"📂 Output directory: {self.session_output_dir}")
        print(f"🔢 Max rounds: {self.max_rounds}")

        # Add to the conversation history
        self.conversation_history.append({"role": "user", "content": initial_prompt})

        try:
            while self.current_round < self.max_rounds:
                self.current_round += 1
                print(f"\n🔄 Analysis round {self.current_round}")

                try:
                    # Get the variable info from the current execution environment
                    notebook_variables = self.executor.get_environment_info()

                    # Format the system prompt with the dynamic variables.
                    # Note: the template in prompts.py now takes three variables.
                    formatted_system_prompt = data_analysis_system_prompt.format(
                        notebook_variables=notebook_variables,
                        user_input=user_input,
                        template_content=self.template_content,
                    )

                    response = self.llm.call(
                        prompt=self._build_conversation_prompt(),
                        system_prompt=formatted_system_prompt,
                    )

                    print(f"🤖 Assistant response:\n{response}")

                    # Use the unified response handler
                    process_result = self._process_response(response)

                    # Decide whether to continue (only outside forced mode)
                    if not self.force_max_rounds and not process_result.get(
                        "continue", True
                    ):
                        print("\n✅ Analysis complete!")
                        break

                    # Add to the conversation history
                    self.conversation_history.append(
                        {"role": "assistant", "content": response}
                    )

                    # Append different feedback depending on the action type
                    if process_result["action"] == "generate_code":
                        feedback = process_result.get("feedback", "")
                        self.conversation_history.append(
                            {"role": "user", "content": f"代码执行反馈:\n{feedback}"}
                        )

                        # Record the analysis result
                        self.analysis_results.append(
                            {
                                "round": self.current_round,
                                "code": process_result.get("code", ""),
                                "result": process_result.get("result", {}),
                                "response": response,
                            }
                        )
                    elif process_result["action"] == "collect_figures":
                        # Record the figure collection results
                        collected_figures = process_result.get("collected_figures", [])
                        missing_figures = process_result.get("missing_figures", [])

                        feedback = f"已收集 {len(collected_figures)} 个有效图片及其分析。"
                        if missing_figures:
                            feedback += f"\n⚠️ 以下图片未找到,请检查代码是否成功保存了这些图片: {missing_figures}"

                        self.conversation_history.append(
                            {
                                "role": "user",
                                "content": f"图片收集反馈:\n{feedback}\n请继续下一步分析。",
                            }
                        )

                        # Record in the analysis results
                        self.analysis_results.append(
                            {
                                "round": self.current_round,
                                "action": "collect_figures",
                                "collected_figures": collected_figures,
                                "missing_figures": missing_figures,
                                "response": response,
                            }
                        )

                except LLMCallError as e:
                    error_msg = str(e)
                    self.fatal_error = error_msg
                    self.fatal_error_stage = "模型调用"
                    print(f"❌ {error_msg}")
                    break
                except WorkerTimeoutError as e:
                    error_msg = str(e)
                    self.fatal_error = error_msg
                    self.fatal_error_stage = "代码执行超时"
                    print(f"❌ {error_msg}")
                    break
                except WorkerSessionError as e:
                    error_msg = str(e)
                    self.fatal_error = error_msg
                    self.fatal_error_stage = "执行子进程异常"
                    print(f"❌ {error_msg}")
                    break
                except Exception as e:
                    error_msg = f"分析流程错误: {str(e)}"
                    print(f"❌ {error_msg}")
                    self.conversation_history.append(
                        {
                            "role": "user",
                            "content": f"发生错误: {error_msg},请重新生成代码。",
                        }
                    )

            # Generate the final summary
            if self.current_round >= self.max_rounds:
                print(f"\n⚠️ Reached the maximum number of rounds ({self.max_rounds}); ending analysis")

            return self._generate_final_report()
        finally:
            if self.executor and not keep_session_open:
                self.executor.close()
                self.executor = None
    def _build_conversation_prompt(self) -> str:
        """Build the conversation prompt."""
        prompt_parts = []

        for msg in self.conversation_history:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                prompt_parts.append(f"用户: {content}")
            else:
                prompt_parts.append(f"助手: {content}")

        return "\n\n".join(prompt_parts)
    def _generate_final_report(self) -> Dict[str, Any]:
        """Generate the final analysis report."""
        if self.fatal_error:
            final_report_content = (
                "# 分析任务失败\n\n"
                f"- 失败阶段: 第 {self.current_round} 轮{self.fatal_error_stage or '分析流程'}\n"
                f"- 错误信息: {self.fatal_error}\n"
                "- 建议: 根据失败阶段检查模型配置、执行子进程日志和数据文件可用性。"
            )
            report_file_path = os.path.join(self.session_output_dir, "最终分析报告.md")
            try:
                with open(report_file_path, "w", encoding="utf-8") as f:
                    f.write(final_report_content)
                print(f"📄 Failure report saved to: {report_file_path}")
            except Exception as e:
                print(f"❌ Failed to save the failure report: {str(e)}")

            return {
                "session_output_dir": self.session_output_dir,
                "total_rounds": self.current_round,
                "analysis_results": self.analysis_results,
                "collected_figures": [],
                "conversation_history": self.conversation_history,
                "final_report": final_report_content,
                "report_file_path": report_file_path,
            }

        # Collect info on all generated figures
        all_figures = []
        for result in self.analysis_results:
            if result.get("action") == "collect_figures":
                all_figures.extend(result.get("collected_figures", []))

        print("\n📊 Generating the final analysis report...")
        print(f"📂 Output directory: {self.session_output_dir}")
        print(f"🔢 Total rounds: {self.current_round}")
        print(f"📈 Collected figures: {len(all_figures)}")

        # Build the prompt for generating the final report
        final_report_prompt = self._build_final_report_prompt(all_figures)

        try:  # Call the LLM to generate the final report
            response = self.llm.call(
                prompt=final_report_prompt,
                system_prompt="你将会接收到一个数据分析任务的最终报告请求,请根据提供的分析结果和图片信息生成完整的分析报告。",
                max_tokens=16384,  # A large token limit so the full report fits
            )

            # Parse the response and extract the final report
            try:
                # Try parsing as YAML
                yaml_data = self.llm.parse_yaml_response(response)

                # Case 1: standard YAML format with action: analysis_complete
                if yaml_data.get("action") == "analysis_complete":
                    final_report_content = yaml_data.get("final_report", response)

                # Case 2: parsing succeeded but the field is missing, or parsing failed
                else:
                    # If the content looks like a Markdown report (has headings), use it directly
                    if "# " in response or "## " in response:
                        print("⚠️ No standard YAML action detected, but the content looks like a Markdown report; using it as-is")
                        final_report_content = response
                    else:
                        final_report_content = "LLM未返回有效报告内容"

            except Exception as e:
                # Parsing failed entirely; fall back to the raw response
                print(f"⚠️ YAML parsing failed ({e}); using the raw response as the report")
                final_report_content = response

            print("✅ Final report generated")

        except Exception as e:
            print(f"❌ Error while generating the final report: {str(e)}")
            final_report_content = f"报告生成失败: {str(e)}"

        # Save the final report to a file
        report_file_path = os.path.join(self.session_output_dir, "最终分析报告.md")
        try:
            with open(report_file_path, "w", encoding="utf-8") as f:
                f.write(final_report_content)
            print(f"📄 Final report saved to: {report_file_path}")
        except Exception as e:
            print(f"❌ Failed to save the report file: {str(e)}")

        # Return the complete analysis result
        return {
            "session_output_dir": self.session_output_dir,
            "total_rounds": self.current_round,
            "analysis_results": self.analysis_results,
            "collected_figures": all_figures,
            "conversation_history": self.conversation_history,
            "final_report": final_report_content,
            "report_file_path": report_file_path,
        }
    def _build_final_report_prompt(self, all_figures: List[Dict[str, Any]]) -> str:
        """Build the prompt used to generate the final report."""

        # Build the figure summary, using relative paths
        figures_summary = ""
        if all_figures:
            figures_summary = "\n生成的图片及分析:\n"
            for i, figure in enumerate(all_figures, 1):
                filename = figure.get("filename", "未知文件名")
                # Use a relative path format suitable for referencing in the report
                relative_path = f"./{filename}"
                figures_summary += f"{i}. {filename}\n"
                figures_summary += f"   相对路径: {relative_path}\n"
                figures_summary += f"   描述: {figure.get('description', '无描述')}\n"
                figures_summary += f"   分析: {figure.get('analysis', '无分析')}\n\n"
        else:
            figures_summary = "\n本次分析未生成图片。\n"

        # Build the code execution summary (successfully executed blocks only)
        code_results_summary = ""
        success_code_count = 0
        for result in self.analysis_results:
            if result.get("action") != "collect_figures" and result.get("code"):
                exec_result = result.get("result", {})
                if exec_result.get("success"):
                    success_code_count += 1
                    code_results_summary += f"代码块 {success_code_count}: 执行成功\n"
                    if exec_result.get("output"):
                        code_results_summary += (
                            f"输出: {exec_result.get('output')}\n\n"
                        )

        # Use the unified prompt template from prompts.py, plus relative-path instructions
        prompt = final_report_system_prompt.format(
            current_round=self.current_round,
            session_output_dir=self.session_output_dir,
            data_profile=self.data_profile,  # Inject the data profile
            figures_summary=figures_summary,
            code_results_summary=code_results_summary,
        )

        # Explicitly require relative paths in the prompt
        prompt += """

📁 **图片路径使用说明**:
报告和图片都在同一目录下,请在报告中使用相对路径引用图片:
- 格式: ![图片描述](./图片文件名.png)
- 示例: ![营业总收入趋势图](./figure_1.png)
- 这样可以确保报告在不同环境下都能正确显示图片
"""

        return prompt
    def reset(self):
        """Reset the agent state."""
        self.conversation_history = []
        self.analysis_results = []
        self.current_round = 0
        self.fatal_error = ""
        self.fatal_error_stage = ""
        self.session_files = []
        self.template_content = "未提供特定模板,请根据数据画像自主发挥。"
        self.close_session()

    def close_session(self):
        """Close the current analysis session but keep the object instance."""
        if self.executor:
            self.executor.close()
            self.executor = None
    def _initialize_session(
        self,
        user_input: str,
        files: List[str],
        template_path: str,
        session_output_dir: str,
    ) -> None:
        if session_output_dir:
            self.session_output_dir = session_output_dir
        else:
            self.session_output_dir = create_session_output_dir(
                self.base_output_dir, user_input
            )

        self.session_files = list(files)
        self.executor = ExecutionSessionClient(
            output_dir=self.session_output_dir,
            allowed_files=self.session_files,
        )

        data_profile = ""
        if self.session_files:
            print("🔍 Generating the data profile...")
            data_profile = load_and_profile_data(self.session_files)
            print("✅ Data profile generated")
        self.data_profile = data_profile
        self.template_content = self._load_template_content(template_path)

    def _load_template_content(self, template_path: str = None) -> str:
        template_content = "未提供特定模板,请根据数据画像自主发挥。"
        if template_path and os.path.exists(template_path):
            try:
                with open(template_path, "r", encoding="utf-8") as f:
                    template_content = f.read()
                print(f"📄 Loaded reference template: {os.path.basename(template_path)}")
            except Exception as e:
                print(f"⚠️ Failed to read the template: {str(e)}")
        return template_content

    def _build_initial_user_prompt(self, user_input: str, files: List[str]) -> str:
        prompt = f"### 任务启动\n**用户需求**: {user_input}\n"
        if files:
            prompt += f"**数据文件**: {', '.join(files)}\n"
        if self.data_profile:
            prompt += f"\n{self.data_profile}\n\n"
        prompt += "请基于上述数据画像和用户问题,自主决定本轮最合适的分析起点。"
        return prompt

    def _build_followup_user_prompt(self, user_input: str) -> str:
        return (
            "### 继续分析专题\n"
            f"**新的专题需求**: {user_input}\n"
            "请基于当前会话中已经加载的数据、已有图表、历史分析结果和中间变量继续分析。"
        )

    def _validate_followup_files(self, files: List[str]) -> None:
        if files and [os.path.abspath(path) for path in files] != [
            os.path.abspath(path) for path in self.session_files
        ]:
            raise ValueError("同一分析会话暂不支持切换为不同的数据文件集合,请新建会话。")
74
main.py
Normal file
@@ -0,0 +1,74 @@
from data_analysis_agent import DataAnalysisAgent
from config.llm_config import LLMConfig

import sys
import os
from datetime import datetime

from utils.create_session_dir import create_session_output_dir


class DualLogger:
    """Logger that writes to both the terminal and a file."""

    def __init__(self, log_dir, filename="log.txt"):
        self.terminal = sys.stdout
        log_path = os.path.join(log_dir, filename)
        self.log = open(log_path, "a", encoding="utf-8")

    def write(self, message):
        self.terminal.write(message)
        # Filter out generated code blocks; do not write them to the log file
        if "🔧 执行代码:" in message:
            return
        self.log.write(message)
        self.log.flush()

    def flush(self):
        self.terminal.flush()
        self.log.flush()


def setup_logging(log_dir):
    """Configure logging."""
    # Record the start time
    logger = DualLogger(log_dir)
    sys.stdout = logger
    # Optional: also redirect stderr
    # sys.stderr = logger
    print(f"\n{'='*20} Run Started at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} {'='*20}\n")
    print(f"📄 日志文件已保存至: {os.path.join(log_dir, 'log.txt')}")


def main():
    llm_config = LLMConfig()
    files = ["./UB IOV Support_TR.csv"]

    # Simplified requirement: let the Agent plan on its own
    analysis_requirement = "我想了解这份工单数据的健康度。请帮我进行全面分析,并找出核心问题点和改进建议。"

    # Create the session directory up front so the log has a home
    base_output_dir = "outputs"
    session_output_dir = create_session_output_dir(base_output_dir, analysis_requirement)

    # Set up logging
    setup_logging(session_output_dir)

    # force_max_rounds=False lets the AI stop early once it considers the work done
    agent = DataAnalysisAgent(llm_config, max_rounds=20, force_max_rounds=False)

    # Pass a path here if you have a specific reference template file
    template_path = None

    report = agent.analyze(
        user_input=analysis_requirement,
        files=files,
        template_path=template_path,
        session_output_dir=session_output_dir
    )

    print("\n" + "="*60)
    print("✅ 分析任务圆满完成!")
    print(f"📊 报告路径: {report.get('report_file_path')}")
    print("="*60)


if __name__ == "__main__":
    main()
147
prompts.py
Normal file
@@ -0,0 +1,147 @@
data_analysis_system_prompt = """你是一个顶级 AI 数据分析专家 (Data Scientist Agent),运行在 Jupyter Notebook 环境中。你的使命是:像人类专家一样理解数据、选择合适的方法、自主生成分析代码,并给出有证据支撑的业务结论。

### 核心工作原则 (Core Principles)
1. **目标驱动**:围绕用户问题组织分析,不机械套用固定流程,也不要为了使用某种方法而使用某种方法。
2. **方法自选**:你可以自由决定本轮要做描述统计、可视化、异常检测、相关分析、分组比较、回归、聚类、文本挖掘或根因拆解;但每次都要说明为什么当前方法适合当前问题。
3. **证据优先**:每个重要结论都必须有指标、表格、图形或统计检验支撑。没有证据时,要明确说明“不足以得出结论”。
4. **隐私保护**:
   - 严禁直接转储大段原始数据。
   - 优先输出聚合结果、摘要统计、分布、分组对比和样本量信息。
   - 原始数据处理必须在本地 Python 环境中完成。
5. **动态下钻**:发现异常、结构性差异或可疑模式后,可以继续细分维度追查原因;如果没有新增价值,也可以停止无效下钻。
6. **业务解释**:每张图、每个统计结果都要回答“这说明了什么”“对业务意味着什么”。

---

### 可用环境 (Environment)
- 数据处理:`pandas`, `numpy`, `duckdb`, `scipy`
- 可视化:`matplotlib`, `seaborn`, `plotly`
- 建模与统计:`scikit-learn`
- 文本分析:`sklearn.feature_extraction.text`
- 中文图表:已支持中文字体显示

---

### 方法选择原则 (Method Guidance)
- 如果目标是理解数据结构,可使用:维度、字段类型、缺失率、分布、重复率、时间范围、分组汇总。
- 如果目标是识别异常,可使用:分位数/IQR、Z-score、分组对比、时序波动、异常检测模型等合适方法。
- 如果目标是解释驱动因素,可使用:相关分析、分组均值比较、卡方检验、回归、树模型特征重要性等合适方法。
- 如果目标是发现结构分层,可使用:聚类、分层分组、贡献度分析、Pareto 分析等合适方法。
- 如果目标涉及文本字段,可使用:2-gram/3-gram、TF-IDF、高频短语、相似文本聚类等合适方法。
- 不要写死方法顺序;优先选择当前数据和问题最匹配的方法。

### 输出质量标准 (Evidence Standards)
- 结论必须对应到可复核的指标或图表。
- 做分组比较时,要说明样本量,避免基于极小样本下结论。
- 使用建模或统计检验时,要简要说明目标变量、特征或检验对象。
- 如果图表没有明显增量价值,可以只输出表格和结论。
- 如果用户提供了参考模板,可以借用其结构,但不要被模板限制住分析判断。

### 统计方法使用规范 (Statistical Quality Rules)
- 选择统计方法时,要说明它为什么适合当前问题,而不是机械套用。
- 进行分组比较时,至少说明样本量、比较维度和核心差异指标。
- 使用相关分析时,要区分“相关”与“因果”,避免把相关性直接解释为因果关系。
- 使用回归、树模型或聚类时,要明确目标变量、输入特征和结果含义,不要只输出模型分数。
- 使用异常检测时,要交代异常判定依据,例如分位数阈值、标准差阈值、模型评分或时序偏离程度。
- 使用文本分析时,要输出短语、主题或类别模式,并结合业务语境解释,不要只给孤立词频。
- 如果样本量过小、字段质量不足或方法前提不满足,应降低结论强度,并明确说明局限性。

### 证据链要求 (Evidence Chain)
- 每个关键发现尽量形成以下链路:指标/统计结果 -> 图表或表格 -> 业务解释 -> 行动建议。
- 如果引用图表,必须说明该图回答了什么问题,而不是只展示图片。
- 如果没有图表,也必须提供足够的数值证据或表格摘要支撑结论。
- 建议项必须尽量回扣前面的证据,不要给与数据无关的泛化建议。

---

### 强制约束规则 (Hard Constraints)
1. **图片保存**:必须使用 `os.path.join(session_output_dir, '中文文件名.png')` 保存,并打印绝对路径。严禁使用 `plt.show()`。
2. **图表规范**:
   - 类别较多时优先使用条形图,避免难以阅读的饼图。
   - 必须设置 `plt.rcParams['font.sans-serif']` 确保中文不乱码。
   - 图表标题、坐标轴和图例应可直接支持业务解读。
3. **文本分析**:如果分析文本字段,优先提取短语或主题,不要只统计单字频率。
4. **响应格式**:始终输出合法 YAML。
5. **动作选择**:
   - 当需要继续分析时,使用 `generate_code`
   - 当已经生成关键图表且需要纳入最终证据链时,使用 `collect_figures`
   - 当分析已足以回答用户问题时,使用 `analysis_complete`

---

### 响应格式示例

**探索/分析轮:**
```yaml
action: "generate_code"
reasoning: "当前先验证数据结构、关键字段分布和样本量,再决定是否需要进一步做异常检测或分组对比。"
code: |
  import pandas as pd
  import os
  # 读取数据
  # 输出聚合统计
  # 根据结果决定下一步方法
next_steps:
  - "确认关键字段的数据质量"
  - "识别是否存在异常分布或结构性差异"
```

**最终报告轮:**
```yaml
action: "analysis_complete"
final_report: |
  # 专业 Markdown 报告内容...
```

当前 Jupyter Notebook 环境变量:
{notebook_variables}

用户需求:{user_input}
参考模板(如有):{template_content}
"""

# System prompt for final report generation
final_report_system_prompt = """你是一位资深数据分析专家 (Senior Data Analyst)。你的任务是基于详细的数据分析过程,撰写一份专业级、可落地的业务分析报告。

### 输入上下文
- **数据画像 (Data Profile)**:
{data_profile}

- **分析过程与关键代码发现**:
{code_results_summary}

- **可视化证据链 (Visual Evidence)**:
{figures_summary}
> **警告**:你必须引用已生成的关键图表。引用格式为 `![图表说明](图片路径)`。

### 报告核心要求
1. **客观性**:严禁使用“我”、“我们”等主观人称,采用陈述性语气。
2. **闭环性**:报告必须针对初始计划中的所有核心方向给出明确的“结论”或“由于数据受限无法给出结论的原因”。
3. **行动力**:最后的“建议矩阵”必须具体到部门、周期和预期收益。

---

### 报告结构模板 (Markdown)

# [项目/产品名称] 业务洞察报告

## 1. 摘要 (Executive Summary)
- **核心洞察**:[一句话概括]
- **健康度评分**:[0-100]
- **TOP 3 关键发现**:
- **核心建议预览**:

## 2. 数据理解与方法论 (Methodology)
- **数据覆盖范围**:[时间窗口、样本量]
- **分析框架**:[如:用户漏斗、RCA 归因、时序预测等]
- **局限性声明**:[如:数据缺失某字段导致的分析受限]

## 3. 核心发现与洞察 (Key Insights)
[按业务主题展开,每个主题必须包含证据图表、数据表现、业务结论]

## 4. 专项根因分析 (Deep Dive)
[针对分析过程中发现的异常点进行的下钻分析结论]

## 5. 建议与行动矩阵 (Recommendations)
[建议项 | 优先级 | 关键举措 | 预期收益 | 负责人]
"""
447
require.md
Normal file
@@ -0,0 +1,447 @@
# A True AI Data Analysis Agent - Requirements Document

## 1. Background

### 1.1 Current Problems

The existing system is an awkward hybrid:
- Task planning: template-based rule generation (a fixed set of 90 tasks)
- Task execution: AI-driven ReAct mode
- Result: rules + AI = inconsistent and inflexible

### 1.2 Core Problem

**What users actually ask for**:
> "I have some data, please analyze it for me"
> "I want to understand the health of my tickets"
> "Follow this template, but adjust flexibly"

**What the system should do**:
- Understand data the way a human analyst does
- Decide autonomously what to analyze
- Adjust the analysis plan based on findings
- Produce insightful reports

**Not**:
- Mechanically execute fixed tasks
- Rigidly fill in a template

## 2. User Stories

### 2.1 Scenario 1: Fully Autonomous Analysis

**As a** data analyst
**I want to** upload a data file and let the AI analyze it automatically
**So that** I can quickly understand the key facts in the data

**Acceptance criteria**:
- The AI can identify the data type (tickets, sales, users, etc.)
- The AI can infer the business meaning of key fields
- The AI can choose analysis dimensions on its own
- The AI can generate a reasonable analysis plan
- The AI can run the analysis and produce a report
- The report contains key findings and insights

**Example**:
```
Input: cleaned_data.csv
Output:
- Data type: ticket data
- Key findings:
  * Pending tickets account for 50% (abnormally high)
  * One vehicle model accounts for 80% of issues
  * Average handling time is 2x the standard
- Recommendation: prioritize clearing the backlog for that vehicle model
```

### 2.2 Scenario 2: User-Specified Analysis Direction

**As a** business owner
**I want to** specify an analysis direction (e.g. "health")
**So that** I get targeted analysis results

**Acceptance criteria**:
- The AI understands the business meaning of "health"
- The AI can turn abstract concepts into concrete metrics
- The AI can pick suitable analysis methods based on the data's characteristics
- The AI can produce a targeted report

**Example**:
```
Input:
- Data: cleaned_data.csv
- Request: "I want to understand ticket health"

AI interpretation:
- Health = closure rate + handling efficiency + backlog + response timeliness

AI analysis:
- Closure rate: 75% (medium)
- Average handling time: 48 hours (long)
- Backlogged tickets: 50% (severe)
- Health score: 60/100 (needs improvement)
```
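The health-score decomposition above can be sketched as a weighted aggregate. The weights, the 24-hour baseline, and the penalty factor below are illustrative assumptions chosen to reproduce the example numbers, not part of the requirements:

```python
# Hypothetical sketch of the "health = closure + efficiency + backlog" idea.
# All weights and thresholds are assumptions for illustration only.

def ticket_health_score(closure_rate: float, avg_hours: float, backlog_ratio: float) -> float:
    """Combine component metrics into a 0-100 health score."""
    closure_score = closure_rate * 100                       # 0.75 -> 75
    efficiency_score = max(0.0, 100 - (avg_hours - 24) * 2)  # penalize hours beyond 24
    backlog_score = (1 - backlog_ratio) * 100                # 50% backlog -> 50

    weights = {"closure": 0.4, "efficiency": 0.3, "backlog": 0.3}
    return round(
        closure_score * weights["closure"]
        + efficiency_score * weights["efficiency"]
        + backlog_score * weights["backlog"],
        1,
    )

# Roughly matches the scenario above: 75% closure, 48h, 50% backlog -> ~60/100
print(ticket_health_score(0.75, 48.0, 0.5))
```

In the real system the AI would derive such a decomposition itself; the point is only that abstract "health" must bottom out in computable metrics.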

### 2.3 Scenario 3: Template-Guided Analysis

**As a** data analyst
**I want to** use a template as a reference framework
**So that** report structure stays consistent while the analysis stays flexible

**Acceptance criteria**:
- The AI understands the template's structure and requirements
- The AI checks whether the data satisfies the template
- If the data lacks certain fields, the AI adjusts flexibly
- The AI organizes the report following the template structure
- The AI does not fail just because the data does not fully match

**Example**:
```
Input:
- Data: cleaned_data.csv
- Template: issue_analysis.md (requires 14 charts)

AI check:
- The template asks for a "severity distribution", but the data has no "severity" field
- Decision: skip that analysis and note it in the report

AI adjustment:
- Run the other 13 analyses
- Note in the report: "The data lacks a severity field, so this dimension cannot be analyzed"
```

### 2.4 Scenario 4: Iterative Deep-Dive Analysis

**As a** data analyst
**I want** the AI to dig deeper based on what it finds
**So that** it can locate root causes

**Acceptance criteria**:
- The AI can recognize anomalies or key findings
- The AI decides on its own whether to drill down
- The AI can adjust the analysis plan dynamically
- The AI can trace a problem to its root cause

**Example**:
```
Initial analysis:
- Finding: pending tickets account for 50% (abnormally high)

AI decision: drill down

Deep dive 1:
- Profile the pending tickets
- Finding: one vehicle model accounts for 80%

AI decision: keep drilling

Deep dive 2:
- Analyze that model's issue types
- Finding: all are "remote control" issues

AI decision: keep drilling

Deep dive 3:
- Analyze the module distribution of the "remote control" issues
- Finding: 90% are in the "door module"

Conclusion: the door module's remote control function has a systemic defect
```

## 3. Functional Requirements

### 3.1 Data Understanding

**FR-1.1 Data loading**
- The system shall support CSV data
- The system shall auto-detect encodings (UTF-8, GBK, etc.)
- The system shall handle common data format issues

**FR-1.2 Data type identification**
- The AI shall analyze column names, data types, and value distributions
- The AI shall infer the business type of the data (tickets, sales, users, etc.)
- The AI shall identify key fields (time, status, category, numeric)

**FR-1.3 Field semantics**
- The AI shall infer the business meaning of each field
- The AI shall identify relationships between fields
- The AI shall identify candidate analysis dimensions

**FR-1.4 Data quality assessment**
- The AI shall check for missing values
- The AI shall check for outliers
- The AI shall produce a data quality score

### 3.2 Requirement Understanding

**FR-2.1 Autonomous requirement inference**
- When the user specifies no requirement, the AI shall infer common analysis needs from the data type
- The AI shall generate default analysis goals

**FR-2.2 User requirement understanding**
- The AI shall understand natural-language requirements
- The AI shall turn abstract concepts into concrete metrics
- The AI shall judge whether the data can support the requirement

**FR-2.3 Template understanding**
- The AI shall parse the template structure
- The AI shall understand the metrics and charts the template calls for
- The AI shall check whether the data satisfies the template
- The AI shall adjust flexibly when it does not

### 3.3 Analysis Planning

**FR-3.1 Dynamic task generation**
- The AI shall generate analysis tasks from the data's characteristics and the requirement
- Tasks shall be dynamic, not fixed
- Tasks shall carry priorities and dependencies

**FR-3.2 Task prioritization**
- The AI shall order tasks by importance
- Required analyses run first
- Optional analyses run later

**FR-3.3 Plan adjustment**
- The AI shall adjust the plan based on intermediate results
- The AI shall be able to add new deep-dive tasks
- The AI shall be able to skip inapplicable tasks

### 3.4 Tool Management

**FR-4.1 Preset tool set**
- The system shall provide a basic data analysis tool set
- Basic tools include: data querying, statistical analysis, visualization, data cleaning
- Tools shall have a standard interface and description
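As a sketch of what such a standard interface could look like (the field names and the `describe_tools` helper are assumptions for illustration, not a specification):

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List

@dataclass
class Tool:
    """One entry in the preset tool set (hypothetical shape)."""
    name: str
    description: str          # read by the planner when choosing tools
    func: Callable[..., Any]  # runs locally; should return aggregates only
    enabled: bool = True
    params: Dict[str, Any] = field(default_factory=dict)

def describe_tools(tools: List[Tool]) -> str:
    """Render the enabled tools as a prompt fragment for the planner."""
    return "\n".join(f"- {t.name}: {t.description}" for t in tools if t.enabled)

tools = [
    Tool("group_stats", "grouped aggregation over one column", lambda df: None),
    Tool("geo_map", "map visualization", lambda df: None, enabled=False),
]
print(describe_tools(tools))  # only the enabled tool is listed
```

Keeping the interface declarative (name, description, callable) is what makes the dynamic enable/disable behavior in FR-4.2 cheap to implement.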

**FR-4.2 Dynamic tool adjustment**
- The AI shall decide which tools are needed from the data's characteristics
- The AI shall enable/disable tools dynamically based on the analysis need
- The AI shall recognize missing tools and request that they be added

**FR-4.3 Tool adaptation**
- The AI shall adapt tool parameters to the data type
- e.g. time-series data → enable trend analysis tools
- e.g. categorical data → enable distribution analysis tools
- e.g. geographic data → enable map visualization tools

**FR-4.4 Custom tool generation**
- The AI shall be able to generate temporary tools for specific needs
- The AI shall be able to combine existing tools into new capabilities
- Custom tools may optionally be kept after the analysis ends

**Example**:
```
Data characteristics:
- Time fields (created_at, closed_at)
- Categorical fields (status, type, model)
- Numeric fields (duration)

AI decisions:
- Enable: time-series analysis, categorical distribution, numeric statistics
- Disable: geographic analysis (no geo fields)
- Generate: handling-time calculator (closed_at - created_at)
```

### 3.5 Analysis Execution

**FR-5.1 ReAct execution mode**
- Each task shall run in ReAct mode
- The AI shall think → act → observe → judge
- The AI shall learn from errors

**FR-5.2 Tool invocation**
- The AI shall pick suitable tools from the available set
- The AI shall combine multiple tools to complete complex tasks
- The AI shall handle tool-call failures

**FR-5.3 Result validation**
- The AI shall validate each task's result
- The AI shall recognize anomalous results
- The AI shall decide whether to retry or adjust

**FR-5.4 Iterative deep dives**
- The AI shall recognize key findings
- The AI shall decide whether a deep dive is needed
- The AI shall add deep-dive tasks dynamically

### 3.6 Report Generation

**FR-6.1 Key-finding distillation**
- The AI shall distill key findings from all results
- The AI shall identify anomalies and trends
- The AI shall provide insight rather than simply listing data

**FR-6.2 Report structuring**
- The AI shall organize the report structure around the analysis content
- With a template, it shall follow the template structure
- Without one, it shall generate a sensible structure

**FR-6.3 Conclusions and recommendations**
- The AI shall draw conclusions from the analysis results
- The AI shall offer actionable recommendations
- The AI shall state the evidence behind each recommendation

**FR-6.4 Multi-format output**
- The system shall generate Markdown reports
- The system shall support export to Word documents (optional)
- Reports shall include all generated charts

## 4. Non-Functional Requirements

### 4.1 Performance

**NFR-1.1 Response time**
- Data understanding stage: < 30 s
- Analysis planning stage: < 60 s
- Single task execution: < 120 s
- Full analysis run: < 30 min (depends on data size and task count)

**NFR-1.2 Data scale**
- CSV files up to 100 MB
- Up to 1,000,000 rows
- Up to 100 columns

### 4.2 Reliability

**NFR-2.1 Error handling**
- Failed AI calls shall have a fallback strategy
- A single failed task shall not break the overall run
- The system shall keep detailed error logs

**NFR-2.2 Data security**
- Data shall be processed locally, never uploaded to external services
- Generated reports shall be saved to the user-specified directory
- Sensitive information shall be masked

### 4.3 Usability

**NFR-3.1 Ease of use**
- The user only needs to provide a data file to start an analysis
- The run shall display progress and status
- Error messages shall be clear and understandable

**NFR-3.2 Observability**
- The system shall surface the AI's reasoning process
- The system shall show progress per stage
- The system shall keep a complete execution log

### 4.4 Extensibility

**NFR-4.1 Tool extension**
- New analysis tools shall be easy to add
- Tools shall have a standard interface
- The AI shall discover and use new tools automatically
- Tools shall support hot loading without a system restart

**NFR-4.2 Tool dynamism**
- The tool set shall adapt to the data's characteristics
- Tool parameters shall adapt to data types
- The system shall support generating temporary tools at runtime

**NFR-4.3 Model extension**
- Shall support multiple LLM providers
- Shall support local and cloud models
- Shall support model switching

## 5. Constraints

### 5.1 Technical Constraints

- Python 3.8+
- OpenAI-compatible LLM APIs
- pandas for data processing
- matplotlib for visualization

### 5.2 Business Constraints

- The system shall work offline (except for LLM calls)
- The system shall not depend on a specific data format or business domain
- The system shall stay generic across data-analysis scenarios

### 5.3 Privacy and Security Constraints

**Data privacy protection**:
- The AI must not access full raw data content
- The AI may only read:
  - Headers (column names)
  - Data type information
  - Basic statistical summaries (row/column counts, missing-value ratios, dtype distribution)
  - Aggregated tool outputs (e.g. group-by statistics, chart data)
- All raw data processing happens locally and is never sent to the LLM
- The AI analyzes data by calling local tools, which return summaries rather than raw data

## 6. Acceptance Criteria

### 6.1 Scenario 1

- [ ] Upload any CSV; the AI identifies the data type
- [ ] The AI generates an analysis plan on its own
- [ ] The AI runs the analysis and produces a report
- [ ] The report contains key findings and insights

### 6.2 Scenario 2

- [ ] Given an abstract requirement such as "health", the AI understands it
- [ ] The AI derives the relevant metrics
- [ ] The AI runs a targeted analysis
- [ ] The report stays focused on the user's requirement

### 6.3 Scenario 3

- [ ] Given a template, the AI understands its requirements
- [ ] When fields are missing, the AI adjusts flexibly
- [ ] The report follows the template structure
- [ ] The report states which analyses were skipped and why

### 6.4 Scenario 4

- [ ] The AI recognizes anomalous findings
- [ ] The AI decides on its own to drill down
- [ ] The AI adjusts the analysis plan dynamically
- [ ] The report includes the deep-dive results

### 6.5 Tool Dynamism

- [ ] The system auto-enables tools relevant to the data's characteristics
- [ ] The system auto-disables irrelevant tools
- [ ] The AI recognizes tools that are needed but missing
- [ ] The AI generates temporary tools for specific needs
- [ ] Tool parameters adapt automatically to data types

## 7. Success Metrics

### 7.1 Functionality

- Data-type identification accuracy > 90%
- Field-semantics inference accuracy > 80%
- Analysis-plan soundness (human-rated) > 85%
- Report quality (human-rated) > 80%

### 7.2 Performance

- Full-run completion rate > 95%
- AI call success rate > 90%

### 7.3 User Satisfaction

- Users find the results valuable > 80%
- Users willing to use it again > 85%
- Users recommend it to others > 75%

---

**Version**: v3.0.0
**Date**: 2026-03-06
**Status**: Requirements finalized
55
requirements.txt
Normal file
@@ -0,0 +1,55 @@
# Data analysis and scientific computing
pandas>=2.0.0
openpyxl>=3.1.0
numpy>=1.24.0
matplotlib>=3.6.0
duckdb>=0.8.0
scipy>=1.10.0
scikit-learn>=1.3.0

# Web and API
requests>=2.28.0
urllib3>=1.26.0
fastapi>=0.110.0
uvicorn>=0.29.0
python-multipart>=0.0.9

# Plotting and visualization
plotly>=5.14.0
dash>=2.0.0

# Flowchart support (optional, for Mermaid diagrams)
# Note: Mermaid diagrams are rendered in Markdown and need no extra Python package.
# To generate Mermaid code from Python, consider:
# mermaid-py>=0.3.0

# Jupyter/IPython environment
ipython>=8.10.0
jupyter>=1.0.0

# AI/LLM
openai>=1.0.0
pyyaml>=6.0

# Configuration management
python-dotenv>=1.0.0

# Async programming
asyncio-mqtt>=0.11.1
nest_asyncio>=1.5.0

# Document generation (Word export of the report)
python-docx>=0.8.11

# System and utility libraries
pathlib2>=2.3.7
typing-extensions>=4.5.0

# Development and testing tools (optional)
pytest>=7.0.0
pytest-asyncio>=0.21.0
black>=23.0.0
flake8>=6.0.0

# Font support (Chinese text in matplotlib)
fonttools>=4.38.0
16
utils/__init__.py
Normal file
@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""
Package initializer for the utils module.
"""

from utils.code_executor import CodeExecutor
from utils.execution_session_client import ExecutionSessionClient
from utils.llm_helper import LLMHelper
from utils.fallback_openai_client import AsyncFallbackOpenAIClient

__all__ = [
    "CodeExecutor",
    "ExecutionSessionClient",
    "LLMHelper",
    "AsyncFallbackOpenAIClient",
]
456
utils/code_executor.py
Normal file
@@ -0,0 +1,456 @@
# -*- coding: utf-8 -*-
"""
Safe code executor: IPython-backed, notebook-style code execution.
"""

import os
import sys
import ast
import traceback
import io
from typing import Dict, Any, List, Optional, Tuple
from contextlib import redirect_stdout, redirect_stderr
from IPython.core.interactiveshell import InteractiveShell
from IPython.utils.capture import capture_output
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm


class CodeExecutor:
    """
    Safe code executor: restricts importable libraries, captures output, and
    supports saving figures with their paths printed.
    """

    ALLOWED_IMPORTS = {
        "pandas", "pd", "numpy", "np",
        "matplotlib", "matplotlib.pyplot", "plt",
        "seaborn", "sns", "duckdb", "scipy", "sklearn", "statsmodels",
        "plotly", "dash", "requests", "urllib",
        "os", "sys", "json", "csv", "datetime", "time", "math", "statistics",
        "re", "pathlib", "io", "collections", "itertools", "functools",
        "operator", "warnings", "logging", "copy", "pickle", "gzip", "zipfile",
        "yaml", "typing", "dataclasses", "enum", "sqlite3",
        "jieba", "wordcloud", "PIL", "random", "networkx",
    }

    def __init__(self, output_dir: str = "outputs"):
        """
        Initialize the executor.

        Args:
            output_dir: output directory for saved figures and files
        """
        self.output_dir = os.path.abspath(output_dir)
        os.makedirs(self.output_dir, exist_ok=True)

        # Give every executor its own shell so analysis tasks never share state
        self.shell = InteractiveShell()

        # Initialize the isolated execution environment
        self.reset_environment()

        # Figure counter
        self.image_counter = 0

    def _setup_chinese_font(self):
        """Configure matplotlib for Chinese font rendering."""
        try:
            # Use the Agg backend to avoid GUI issues
            matplotlib.use("Agg")

            # Fonts available on this system
            available_fonts = [f.name for f in fm.fontManager.ttflist]

            # Common Chinese fonts, ordered by preference
            chinese_fonts = [
                "Hiragino Sans GB",     # macOS Simplified Chinese
                "Songti SC",            # macOS Songti (Simplified)
                "PingFang SC",          # macOS PingFang (Simplified)
                "Heiti SC",             # macOS Heiti (Simplified)
                "Heiti TC",             # macOS Heiti (Traditional)
                "PingFang HK",          # macOS PingFang (Hong Kong)
                "SimHei",               # Windows SimHei
                "STHeiti",              # STHeiti
                "WenQuanYi Micro Hei",  # Linux WenQuanYi Micro Hei
                "DejaVu Sans",          # default sans-serif
                "Arial Unicode MS",     # Arial Unicode
            ]

            # Keep only the fonts that actually exist on this system
            system_chinese_fonts = [
                font for font in chinese_fonts if font in available_fonts
            ]

            # Fall back to a looser search if no exact match was found
            if not system_chinese_fonts:
                print("警告:未找到精确匹配的中文字体,尝试更宽松的搜索...")
                # Looser matching on partial font names
                fallback_fonts = []
                for available_font in available_fonts:
                    if any(
                        keyword in available_font
                        for keyword in [
                            "Hei",
                            "Song",
                            "Fang",
                            "Kai",
                            "Hiragino",
                            "PingFang",
                            "ST",
                        ]
                    ):
                        fallback_fonts.append(available_font)

                if fallback_fonts:
                    system_chinese_fonts = fallback_fonts[:3]  # keep the first 3 matches
                    print(f"找到备选中文字体: {system_chinese_fonts}")
                else:
                    print("警告:系统中未找到合适的中文字体,使用系统默认字体")
                    system_chinese_fonts = ["DejaVu Sans", "Arial Unicode MS"]

            # Apply the font configuration
            plt.rcParams["font.sans-serif"] = system_chinese_fonts + [
                "DejaVu Sans",
                "Arial Unicode MS",
            ]

            plt.rcParams["axes.unicode_minus"] = False
            plt.rcParams["font.family"] = "sans-serif"

            # Mirror the same font configuration inside the shell
            font_list_str = str(
                system_chinese_fonts + ["DejaVu Sans", "Arial Unicode MS"]
            )
            self.shell.run_cell(
                f"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

# Configure Chinese fonts
plt.rcParams['font.sans-serif'] = {font_list_str}
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.family'] = 'sans-serif'

# Make sure the matplotlib cache directory is writable
import os
cache_dir = os.path.expanduser('~/.matplotlib')
if not os.path.exists(cache_dir):
    os.makedirs(cache_dir, exist_ok=True)
os.environ['MPLCONFIGDIR'] = cache_dir
"""
            )
        except Exception as e:
            print(f"设置中文字体失败: {e}")
            # Even on failure, apply the basic matplotlib configuration
            try:
                matplotlib.use("Agg")
                plt.rcParams["axes.unicode_minus"] = False
            except Exception:
                pass

    def _setup_common_imports(self):
        """Pre-import commonly used libraries."""
        common_imports = """
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import duckdb
import os
import json
from IPython.display import display
"""
        try:
            self.shell.run_cell(common_imports)
            # Make display() available in the shell's user namespace
            from IPython.display import display

            self.shell.user_ns["display"] = display
        except Exception as e:
            print(f"预导入库失败: {e}")

    def _check_code_safety(self, code: str) -> Tuple[bool, str]:
        """
        Check code safety by restricting which libraries may be imported.

        Returns:
            (is_safe, error_message)
        """
        try:
            tree = ast.parse(code)
        except SyntaxError as e:
            return False, f"语法错误: {e}"

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    if alias.name not in self.ALLOWED_IMPORTS:
                        return False, f"不允许的导入: {alias.name}"

            elif isinstance(node, ast.ImportFrom):
                if node.module not in self.ALLOWED_IMPORTS:
                    return False, f"不允许的导入: {node.module}"

            # Check attribute access (blocks bypasses such as os.system)
            elif isinstance(node, ast.Attribute):
                # Attribute access on the os module
                if isinstance(node.value, ast.Name) and node.value.id == "os":
                    # Whitelist of allowed os submodules and functions
                    allowed_os_attributes = {
                        "path", "environ", "getcwd", "listdir", "makedirs", "mkdir", "remove", "rmdir",
                        "path.join", "path.exists", "path.abspath", "path.dirname",
                        "path.basename", "path.splitext", "path.isdir", "path.isfile",
                        "sep", "name", "linesep", "stat", "getpid"
                    }

                    # Direct attribute access (e.g. os.getcwd)
                    if node.attr not in allowed_os_attributes:
                        # Also handle the os.path.xxx form
                        # Note: nested ast.Attribute structures are complex; keep it
                        # simple and allow only os.path plus the whitelist above
                        if node.attr == "path":
                            pass  # allow access to os.path
                        else:
                            return False, f"不允许的os属性访问: os.{node.attr}"

            # Check for dangerous function calls
            elif isinstance(node, ast.Call):
                if isinstance(node.func, ast.Name):
                    if node.func.id in ["exec", "eval", "__import__"]:
                        return False, f"不允许的函数调用: {node.func.id}"

        return True, ""

    def get_current_figures_info(self) -> List[Dict[str, Any]]:
        """Report the current matplotlib figures without saving them."""
        figures_info = []

        # All currently open figures
        fig_nums = plt.get_fignums()

        for fig_num in fig_nums:
            fig = plt.figure(fig_num)
            if fig.get_axes():  # only figures that have content
                figures_info.append(
                    {
                        "figure_number": fig_num,
                        "axes_count": len(fig.get_axes()),
                        "figure_size": fig.get_size_inches().tolist(),
                        "has_content": True,
                    }
                )

        return figures_info

    def _format_table_output(self, obj: Any) -> str:
        """Format tabular output with a row limit."""
        if hasattr(obj, "shape") and hasattr(obj, "head"):  # pandas DataFrame
            rows, cols = obj.shape
            print(f"\n数据表形状: {rows}行 x {cols}列")
            print(f"列名: {list(obj.columns)}")

            if rows <= 15:
                return str(obj)
            else:
                head_part = obj.head(5)
                tail_part = obj.tail(5)
                return f"{head_part}\n...\n(省略 {rows-10} 行)\n...\n{tail_part}"

        return str(obj)

    def execute_code(self, code: str) -> Dict[str, Any]:
        """
        Execute code and return the result.

        Args:
            code: Python code to execute

        Returns:
            {
                'success': bool,
                'output': str,
                'error': str,
                'variables': Dict[str, Any]  # important newly created variables
            }
        """
        # Safety check
        is_safe, safety_error = self._check_code_safety(code)
        if not is_safe:
            return {
                "success": False,
                "output": "",
                "error": f"代码安全检查失败: {safety_error}",
                "variables": {},
            }

        # Snapshot the variables that exist before execution
        vars_before = set(self.shell.user_ns.keys())

        try:
            # Capture all output via IPython's capture_output
            with capture_output() as captured:
                result = self.shell.run_cell(code)

            # Inspect the execution result
            if result.error_before_exec:
                error_msg = str(result.error_before_exec)
                return {
                    "success": False,
                    "output": captured.stdout,
                    "error": f"执行前错误: {error_msg}",
                    "variables": {},
                }

            if result.error_in_exec:
                error_msg = str(result.error_in_exec)
                return {
                    "success": False,
                    "output": captured.stdout,
                    "error": f"执行错误: {error_msg}",
                    "variables": {},
                }

            # Collect output
            output = captured.stdout

            # Append the cell's return value, if any
            if result.result is not None:
                formatted_result = self._format_table_output(result.result)
                output += f"\n{formatted_result}"
            # Track important newly created variables (simplified)
            vars_after = set(self.shell.user_ns.keys())
            new_vars = vars_after - vars_before

            # Only record significant data structures, e.g. new DataFrames
            important_new_vars = {}
            for var_name in new_vars:
                if not var_name.startswith("_"):
                    try:
                        var_value = self.shell.user_ns[var_name]
                        if hasattr(var_value, "shape"):  # pandas DataFrame, numpy array
                            important_new_vars[var_name] = (
                                f"{type(var_value).__name__} with shape {var_value.shape}"
                            )
                        elif var_name in ["session_output_dir"]:  # important config variable
                            important_new_vars[var_name] = str(var_value)
                    except Exception:
                        pass

            return {
                "success": True,
                "output": output,
                "error": "",
                "variables": important_new_vars,
            }
        except Exception as e:
            return {
                "success": False,
                "output": captured.stdout if "captured" in locals() else "",
                "error": f"执行异常: {str(e)}\n{traceback.format_exc()}",
                "variables": {},
            }

    def reset_environment(self):
        """Reset the execution environment."""
        self.shell.reset()
        self._setup_common_imports()
        self._setup_chinese_font()
        plt.close("all")
        self.image_counter = 0

    def set_variable(self, name: str, value: Any):
        """Set a variable in the execution environment."""
        self.shell.user_ns[name] = value

    def get_environment_info(self) -> str:
        """Summarize the environment's variables for the system prompt."""
        info_parts = []

        # Collect the important data variables
        important_vars = {}
        for var_name, var_value in self.shell.user_ns.items():
            if not var_name.startswith("_") and var_name not in [
                "In",
                "Out",
                "get_ipython",
                "exit",
                "quit",
            ]:
                try:
                    if hasattr(var_value, "shape"):  # pandas DataFrame, numpy array
                        important_vars[var_name] = (
                            f"{type(var_value).__name__} with shape {var_value.shape}"
                        )
                    elif var_name in ["session_output_dir"]:  # important path variable
                        important_vars[var_name] = str(var_value)
                    elif (
                        isinstance(var_value, (int, float, str, bool))
                        and len(str(var_value)) < 100
                    ):
                        important_vars[var_name] = (
                            f"{type(var_value).__name__}: {var_value}"
                        )
                    elif hasattr(var_value, "__module__") and var_value.__module__ in [
                        "pandas",
                        "numpy",
                        "matplotlib.pyplot",
                    ]:
                        important_vars[var_name] = f"导入的模块: {var_value.__module__}"
                except Exception:
                    continue

        if important_vars:
            info_parts.append("当前环境变量:")
            for var_name, var_info in important_vars.items():
                info_parts.append(f"- {var_name}: {var_info}")
        else:
            info_parts.append("当前环境已预装pandas, numpy, matplotlib等库")

        # Output directory info
        if "session_output_dir" in self.shell.user_ns:
            info_parts.append(
                f"图片保存目录: session_output_dir = '{self.shell.user_ns['session_output_dir']}'"
            )

        return "\n".join(info_parts)
15
utils/create_session_dir.py
Normal file
@@ -0,0 +1,15 @@
import os
from datetime import datetime


def create_session_output_dir(base_output_dir, user_input: str) -> str:
    """Create a dedicated output directory for this analysis session."""

    # Build a unique session directory name from the current time (YYYYMMDD_HHMMSS)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    session_id = timestamp
    dir_name = f"session_{session_id}"
    session_dir = os.path.join(base_output_dir, dir_name)
    os.makedirs(session_dir, exist_ok=True)

    return session_dir
90
utils/data_loader.py
Normal file
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
import os

import pandas as pd


def load_and_profile_data(file_paths: list) -> str:
    """
    Load data files and generate a data profile.

    Args:
        file_paths: List of file paths.

    Returns:
        A Markdown string containing the data profile.
    """
    profile_summary = "# 数据画像报告 (Data Profile)\n\n"

    if not file_paths:
        return profile_summary + "未提供数据文件。"

    for file_path in file_paths:
        file_name = os.path.basename(file_path)
        profile_summary += f"## 文件: {file_name}\n\n"

        if not os.path.exists(file_path):
            profile_summary += f"⚠️ 文件不存在: {file_path}\n\n"
            continue

        try:
            # Choose a loader based on the file extension
            ext = os.path.splitext(file_path)[1].lower()
            if ext == '.csv':
                # Try multiple encodings
                try:
                    df = pd.read_csv(file_path, encoding='utf-8')
                except UnicodeDecodeError:
                    try:
                        df = pd.read_csv(file_path, encoding='gbk')
                    except Exception:
                        df = pd.read_csv(file_path, encoding='latin1')
            elif ext in ['.xlsx', '.xls']:
                df = pd.read_excel(file_path)
            else:
                profile_summary += f"⚠️ 不支持的文件格式: {ext}\n\n"
                continue

            # Basic information
            rows, cols = df.shape
            profile_summary += f"- **维度**: {rows} 行 x {cols} 列\n"
            profile_summary += f"- **列名**: `{', '.join(df.columns)}`\n\n"

            profile_summary += "### 列详细分布:\n"

            # Analyse each column in turn
            for col in df.columns:
                dtype = df[col].dtype
                null_count = df[col].isnull().sum()
                null_ratio = (null_count / rows) * 100

                profile_summary += f"#### {col} ({dtype})\n"
                if null_count > 0:
                    profile_summary += f"- ⚠️ 空值: {null_count} ({null_ratio:.1f}%)\n"

                # Numeric columns
                if pd.api.types.is_numeric_dtype(dtype):
                    desc = df[col].describe()
                    profile_summary += f"- 统计: Min={desc['min']:.2f}, Max={desc['max']:.2f}, Mean={desc['mean']:.2f}\n"

                # Text / categorical columns
                elif pd.api.types.is_object_dtype(dtype) or pd.api.types.is_categorical_dtype(dtype):
                    unique_count = df[col].nunique()
                    profile_summary += f"- 唯一值数量: {unique_count}\n"

                    # When the column looks categorical (few unique values), show the
                    # top distribution; this is key for spotting high-frequency issues
                    if unique_count > 0:
                        top_n = df[col].value_counts().head(5)
                        top_items_str = ", ".join([f"{k}({v})" for k, v in top_n.items()])
                        profile_summary += f"- **TOP 5 高频值**: {top_items_str}\n"

                # Datetime columns
                elif pd.api.types.is_datetime64_any_dtype(dtype):
                    profile_summary += f"- 范围: {df[col].min()} 至 {df[col].max()}\n"

                profile_summary += "\n"

        except Exception as e:
            profile_summary += f"❌ 读取或分析文件失败: {str(e)}\n\n"

    return profile_summary
219
utils/execution_session_client.py
Normal file
@@ -0,0 +1,219 @@
# -*- coding: utf-8 -*-
"""
Client for a per-analysis execution worker subprocess.
"""

import json
import os
import queue
import subprocess
import sys
import threading
import uuid
from typing import Any, Dict, Optional


class WorkerSessionError(RuntimeError):
    """Raised when the execution worker cannot serve a request."""


class WorkerTimeoutError(WorkerSessionError):
    """Raised when the worker does not respond within the configured timeout."""


class ExecutionSessionClient:
    """Client that proxies CodeExecutor methods to a dedicated worker process."""

    def __init__(
        self,
        output_dir: str,
        allowed_files=None,
        python_executable: Optional[str] = None,
        request_timeout_seconds: float = 60.0,
        startup_timeout_seconds: float = 180.0,
    ):
        self.output_dir = os.path.abspath(output_dir)
        self.allowed_files = [os.path.abspath(path) for path in (allowed_files or [])]
        self.allowed_read_roots = sorted(
            {os.path.dirname(path) for path in self.allowed_files}
        )
        self.python_executable = python_executable or sys.executable
        self.request_timeout_seconds = request_timeout_seconds
        self.startup_timeout_seconds = startup_timeout_seconds
        self._process: Optional[subprocess.Popen] = None
        self._stderr_handle = None
        self._start_worker()
        self._request(
            "init_session",
            {
                "output_dir": self.output_dir,
                "variables": {
                    "session_output_dir": self.output_dir,
                    "allowed_files": self.allowed_files,
                    "allowed_read_roots": self.allowed_read_roots,
                },
            },
            timeout_seconds=self.startup_timeout_seconds,
        )

    def execute_code(self, code: str) -> Dict[str, Any]:
        return self._request("execute_code", {"code": code})

    def set_variable(self, name: str, value: Any) -> None:
        self._request("set_variable", {"name": name, "value": value})

    def get_environment_info(self) -> str:
        payload = self._request("get_environment_info", {})
        return payload.get("environment_info", "")

    def reset_environment(self) -> None:
        self._request("reset_environment", {})
        self.set_variable("session_output_dir", self.output_dir)
        self.set_variable("allowed_files", self.allowed_files)
        self.set_variable("allowed_read_roots", self.allowed_read_roots)

    def ping(self) -> bool:
        payload = self._request("ping", {})
        return bool(payload.get("alive"))

    def close(self) -> None:
        if self._process is None:
            return

        try:
            if self._process.poll() is None:
                self._request("shutdown", {}, timeout_seconds=5)
        except Exception:
            pass
        finally:
            self._teardown_worker()

    def _start_worker(self) -> None:
        runtime_dir = os.path.join(self.output_dir, ".worker_runtime")
        mpl_dir = os.path.join(runtime_dir, "mplconfig")
        ipython_dir = os.path.join(runtime_dir, "ipython")
        os.makedirs(mpl_dir, exist_ok=True)
        os.makedirs(ipython_dir, exist_ok=True)

        stderr_log_path = os.path.join(self.output_dir, "execution_worker.log")
        self._stderr_handle = open(stderr_log_path, "a", encoding="utf-8")

        worker_script = os.path.join(
            os.path.dirname(__file__),
            "execution_worker.py",
        )
        env = os.environ.copy()
        env["MPLCONFIGDIR"] = mpl_dir
        env["IPYTHONDIR"] = ipython_dir
        env.setdefault("PYTHONIOENCODING", "utf-8")

        self._process = subprocess.Popen(
            [self.python_executable, worker_script],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=self._stderr_handle,
            text=True,
            encoding="utf-8",
            cwd=os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
            env=env,
            bufsize=1,
        )

    def _request(
        self,
        request_type: str,
        payload: Dict[str, Any],
        timeout_seconds: Optional[float] = None,
    ) -> Dict[str, Any]:
        if self._process is None or self._process.stdin is None or self._process.stdout is None:
            raise WorkerSessionError("执行子进程未启动")
        if self._process.poll() is not None:
            raise WorkerSessionError(
                f"执行子进程已退出,退出码: {self._process.returncode}"
            )

        request_id = str(uuid.uuid4())
        message = {
            "request_id": request_id,
            "type": request_type,
            "payload": payload,
        }

        try:
            self._process.stdin.write(json.dumps(message, ensure_ascii=False) + "\n")
            self._process.stdin.flush()
        except BrokenPipeError as exc:
            self._teardown_worker()
            raise WorkerSessionError("执行子进程通信中断") from exc

        effective_timeout = (
            self.request_timeout_seconds
            if timeout_seconds is None
            else timeout_seconds
        )
        response_line = self._read_response_line(effective_timeout)
        if not response_line:
            if self._process.poll() is not None:
                exit_code = self._process.returncode
                self._teardown_worker()
                raise WorkerSessionError(f"执行子进程已异常退出,退出码: {exit_code}")
            raise WorkerSessionError("执行子进程未返回响应")

        try:
            response = json.loads(response_line)
        except json.JSONDecodeError as exc:
            raise WorkerSessionError(f"执行子进程返回了无效JSON: {response_line}") from exc

        if response.get("request_id") != request_id:
            raise WorkerSessionError("执行子进程响应 request_id 不匹配")

        if response.get("status") != "ok":
            raise WorkerSessionError(response.get("error", "执行子进程返回未知错误"))

        return response.get("payload", {})

    def _read_response_line(self, timeout_seconds: float) -> str:
        assert self._process is not None and self._process.stdout is not None

        response_queue: queue.Queue = queue.Queue(maxsize=1)

        def _reader() -> None:
            try:
                response_queue.put((True, self._process.stdout.readline()))
            except Exception as exc:
                response_queue.put((False, exc))

        thread = threading.Thread(target=_reader, daemon=True)
        thread.start()

        try:
            success, value = response_queue.get(timeout=timeout_seconds)
        except queue.Empty as exc:
            self._teardown_worker(force=True)
            raise WorkerTimeoutError(
                f"执行子进程在 {timeout_seconds:.1f} 秒内未响应,已终止当前会话"
            ) from exc

        if success:
            return value

        self._teardown_worker()
        raise WorkerSessionError(f"读取执行子进程响应失败: {value}")

    def _teardown_worker(self, force: bool = False) -> None:
        if self._process is not None and self._process.poll() is None:
            self._process.terminate()
            try:
                self._process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self._process.kill()
                self._process.wait(timeout=5)

        if self._stderr_handle is not None:
            self._stderr_handle.close()
            self._stderr_handle = None

        self._process = None

    def __del__(self):
        self.close()
321
utils/execution_worker.py
Normal file
@@ -0,0 +1,321 @@
# -*- coding: utf-8 -*-
"""
Subprocess worker that hosts a single CodeExecutor instance for one analysis session.
"""

import json
import os
import sys
import traceback
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from pathlib import Path
from typing import Any, Dict, Iterable, Optional

PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)


class WorkerProtocolError(RuntimeError):
    """Raised when the worker receives an invalid protocol message."""


class FileAccessPolicy:
    """Controls which files the worker may read and where it may write outputs."""

    def __init__(self):
        self.allowed_reads = set()
        self.allowed_read_roots = set()
        self.allowed_write_root = ""

    @staticmethod
    def _normalize(path: Any) -> str:
        if isinstance(path, Path):
            path = str(path)
        elif hasattr(path, "__fspath__"):
            path = os.fspath(path)
        elif not isinstance(path, str):
            raise TypeError(f"不支持的路径类型: {type(path).__name__}")
        return os.path.realpath(os.path.abspath(path))

    def configure(
        self,
        allowed_reads: Iterable[Any],
        allowed_write_root: Any,
        allowed_read_roots: Optional[Iterable[Any]] = None,
    ) -> None:
        self.allowed_reads = {
            self._normalize(path) for path in allowed_reads if path
        }
        explicit_roots = {
            self._normalize(path) for path in (allowed_read_roots or []) if path
        }
        derived_roots = {
            os.path.dirname(path) for path in self.allowed_reads
        }
        self.allowed_read_roots = explicit_roots | derived_roots
        self.allowed_write_root = (
            self._normalize(allowed_write_root) if allowed_write_root else ""
        )

    def ensure_readable(self, path: Any) -> str:
        normalized_path = self._normalize(path)
        if normalized_path in self.allowed_reads:
            return normalized_path
        if self._is_within_read_roots(normalized_path):
            return normalized_path
        if self._is_within_write_root(normalized_path):
            return normalized_path
        raise PermissionError(f"禁止读取未授权文件: {normalized_path}")

    def ensure_writable(self, path: Any) -> str:
        normalized_path = self._normalize(path)
        if self._is_within_write_root(normalized_path):
            return normalized_path
        raise PermissionError(f"禁止写入会话目录之外的路径: {normalized_path}")

    def _is_within_write_root(self, normalized_path: str) -> bool:
        if not self.allowed_write_root:
            return False
        return normalized_path == self.allowed_write_root or normalized_path.startswith(
            self.allowed_write_root + os.sep
        )

    def _is_within_read_roots(self, normalized_path: str) -> bool:
        for root in self.allowed_read_roots:
            if normalized_path == root or normalized_path.startswith(root + os.sep):
                return True
        return False


def _write_message(message: Dict[str, Any]) -> None:
    sys.stdout.write(json.dumps(message, ensure_ascii=False) + "\n")
    sys.stdout.flush()


def _write_log(text: str) -> None:
    if not text:
        return
    sys.stderr.write(text)
    if not text.endswith("\n"):
        sys.stderr.write("\n")
    sys.stderr.flush()


class ExecutionWorker:
    """JSON-line protocol wrapper around CodeExecutor."""

    def __init__(self):
        self.executor = None
        self.access_policy = FileAccessPolicy()
        self._patches_installed = False

    def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        request_id = request.get("request_id", "")
        request_type = request.get("type")
        payload = request.get("payload", {})

        try:
            if request_type == "ping":
                return self._ok(request_id, {"alive": True})
            if request_type == "init_session":
                return self._handle_init_session(request_id, payload)
            if request_type == "execute_code":
                self._require_executor()
                return self._ok(
                    request_id,
                    self.executor.execute_code(payload.get("code", "")),
                )
            if request_type == "set_variable":
                self._require_executor()
                self._handle_set_variable(payload["name"], payload["value"])
                return self._ok(request_id, {"set": True})
            if request_type == "get_environment_info":
                self._require_executor()
                return self._ok(
                    request_id,
                    {"environment_info": self.executor.get_environment_info()},
                )
            if request_type == "reset_environment":
                self._require_executor()
                self.executor.reset_environment()
                return self._ok(request_id, {"reset": True})
            if request_type == "shutdown":
                return self._ok(request_id, {"shutdown": True})

            raise WorkerProtocolError(f"未知请求类型: {request_type}")
        except Exception as exc:
            return {
                "request_id": request_id,
                "status": "error",
                "error": str(exc),
                "traceback": traceback.format_exc(),
            }

    def _handle_init_session(
        self, request_id: str, payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        output_dir = payload.get("output_dir")
        if not output_dir:
            raise WorkerProtocolError("init_session 缺少 output_dir")

        from utils.code_executor import CodeExecutor

        self.executor = CodeExecutor(output_dir)
        self.access_policy.configure(
            payload.get("variables", {}).get("allowed_files", []),
            output_dir,
            payload.get("variables", {}).get("allowed_read_roots", []),
        )
        self._install_file_guards()

        for name, value in payload.get("variables", {}).items():
            self.executor.set_variable(name, value)

        return self._ok(request_id, {"initialized": True})

    def _install_file_guards(self) -> None:
        if self._patches_installed:
            return

        import builtins
        import matplotlib.figure
        import matplotlib.pyplot as plt
        import pandas as pd

        policy = self.access_policy

        original_open = builtins.open
        original_read_csv = pd.read_csv
        original_read_excel = pd.read_excel
        original_to_csv = pd.DataFrame.to_csv
        original_to_excel = pd.DataFrame.to_excel
        original_plt_savefig = plt.savefig
        original_figure_savefig = matplotlib.figure.Figure.savefig

        def guarded_open(file, mode="r", *args, **kwargs):
            if isinstance(file, (str, Path)) or hasattr(file, "__fspath__"):
                if any(flag in mode for flag in ("w", "a", "x", "+")):
                    policy.ensure_writable(file)
                else:
                    policy.ensure_readable(file)
            return original_open(file, mode, *args, **kwargs)

        def guarded_read_csv(filepath_or_buffer, *args, **kwargs):
            if isinstance(filepath_or_buffer, (str, Path)) or hasattr(
                filepath_or_buffer, "__fspath__"
            ):
                policy.ensure_readable(filepath_or_buffer)
            return original_read_csv(filepath_or_buffer, *args, **kwargs)

        def guarded_read_excel(io, *args, **kwargs):
            if isinstance(io, (str, Path)) or hasattr(io, "__fspath__"):
                policy.ensure_readable(io)
            return original_read_excel(io, *args, **kwargs)

        def guarded_to_csv(df, path_or_buf=None, *args, **kwargs):
            if isinstance(path_or_buf, (str, Path)) or hasattr(path_or_buf, "__fspath__"):
                policy.ensure_writable(path_or_buf)
            return original_to_csv(df, path_or_buf, *args, **kwargs)

        def guarded_to_excel(df, excel_writer, *args, **kwargs):
            if isinstance(excel_writer, (str, Path)) or hasattr(excel_writer, "__fspath__"):
                policy.ensure_writable(excel_writer)
            return original_to_excel(df, excel_writer, *args, **kwargs)

        def guarded_savefig(*args, **kwargs):
            target = args[0] if args else kwargs.get("fname")
            if target is not None and (
                isinstance(target, (str, Path)) or hasattr(target, "__fspath__")
            ):
                policy.ensure_writable(target)
            return original_plt_savefig(*args, **kwargs)

        def guarded_figure_savefig(fig, fname, *args, **kwargs):
            if isinstance(fname, (str, Path)) or hasattr(fname, "__fspath__"):
                policy.ensure_writable(fname)
            return original_figure_savefig(fig, fname, *args, **kwargs)

        builtins.open = guarded_open
        pd.read_csv = guarded_read_csv
        pd.read_excel = guarded_read_excel
        pd.DataFrame.to_csv = guarded_to_csv
        pd.DataFrame.to_excel = guarded_to_excel
        plt.savefig = guarded_savefig
        matplotlib.figure.Figure.savefig = guarded_figure_savefig

        self._patches_installed = True

    def _require_executor(self) -> None:
        if self.executor is None:
            raise WorkerProtocolError("执行会话尚未初始化")

    def _handle_set_variable(self, name: str, value: Any) -> None:
        self.executor.set_variable(name, value)

        if name == "allowed_files":
            self.access_policy.configure(
                value,
                self.access_policy.allowed_write_root,
                self.access_policy.allowed_read_roots,
            )
        elif name == "allowed_read_roots":
            self.access_policy.configure(
                self.access_policy.allowed_reads,
                self.access_policy.allowed_write_root,
                value,
            )
        elif name == "session_output_dir":
            self.access_policy.configure(
                self.access_policy.allowed_reads,
                value,
                self.access_policy.allowed_read_roots,
            )

    @staticmethod
    def _ok(request_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        return {
            "request_id": request_id,
            "status": "ok",
            "payload": payload,
        }


def main() -> int:
    worker = ExecutionWorker()

    for raw_line in sys.stdin:
        raw_line = raw_line.strip()
        if not raw_line:
            continue

        try:
            request = json.loads(raw_line)
        except json.JSONDecodeError as exc:
            _write_message(
                {
                    "request_id": "",
                    "status": "error",
                    "error": f"无效JSON请求: {exc}",
                }
            )
            continue

        captured_stdout = StringIO()
        captured_stderr = StringIO()
        with redirect_stdout(captured_stdout), redirect_stderr(captured_stderr):
            response = worker.handle_request(request)

        _write_log(captured_stdout.getvalue())
        _write_log(captured_stderr.getvalue())
        _write_message(response)

        if request.get("type") == "shutdown" and response.get("status") == "ok":
            return 0

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
38
utils/extract_code.py
Normal file
@@ -0,0 +1,38 @@
from typing import Optional

import yaml


def extract_code_from_response(response: str) -> Optional[str]:
    """Extract code from an LLM response."""
    try:
        # Try to parse the response as YAML first
        if '```yaml' in response:
            start = response.find('```yaml') + 7
            end = response.find('```', start)
            yaml_content = response[start:end].strip()
        elif '```' in response:
            start = response.find('```') + 3
            end = response.find('```', start)
            yaml_content = response[start:end].strip()
        else:
            yaml_content = response.strip()

        yaml_data = yaml.safe_load(yaml_content)
        if 'code' in yaml_data:
            return yaml_data['code']
    except Exception:
        pass

    # If YAML parsing fails, fall back to extracting a ```python code block
    if '```python' in response:
        start = response.find('```python') + 9
        end = response.find('```', start)
        if end != -1:
            return response[start:end].strip()
    elif '```' in response:
        start = response.find('```') + 3
        end = response.find('```', start)
        if end != -1:
            return response[start:end].strip()

    return None
230
utils/fallback_openai_client.py
Normal file
@@ -0,0 +1,230 @@
# -*- coding: utf-8 -*-
import asyncio
from typing import Optional, Any, Mapping, Dict

from openai import AsyncOpenAI, APIStatusError, APIConnectionError, APITimeoutError, APIError
from openai.types.chat import ChatCompletion


class AsyncFallbackOpenAIClient:
    """
    An async OpenAI client with automatic fallback to a secondary API.
    When a primary API call fails with a specific error (e.g. content filtering),
    the client automatically retries against the fallback API.
    """

    def __init__(
        self,
        primary_api_key: str,
        primary_base_url: str,
        primary_model_name: str,
        fallback_api_key: Optional[str] = None,
        fallback_base_url: Optional[str] = None,
        fallback_model_name: Optional[str] = None,
        primary_client_args: Optional[Dict[str, Any]] = None,
        fallback_client_args: Optional[Dict[str, Any]] = None,
        content_filter_error_code: str = "1301",  # Zhipu-specific content-filter error code
        content_filter_error_field: str = "contentFilter",  # Zhipu-specific content-filter error field
        max_retries_primary: int = 1,  # retries for the primary API
        max_retries_fallback: int = 1,  # retries for the fallback API
        retry_delay_seconds: float = 1.0  # delay before retrying
    ):
        """
        Initialize the AsyncFallbackOpenAIClient.

        Args:
            primary_api_key: API key for the primary API.
            primary_base_url: Base URL of the primary API.
            primary_model_name: Model name used by the primary API.
            fallback_api_key: API key for the fallback API (optional).
            fallback_base_url: Base URL of the fallback API (optional).
            fallback_model_name: Model name used by the fallback API (optional).
            primary_client_args: Extra arguments for the primary AsyncOpenAI client.
            fallback_client_args: Extra arguments for the fallback AsyncOpenAI client.
            content_filter_error_code: Error code of the content-filter error that triggers fallback.
            content_filter_error_field: Field name present in the content-filter error that triggers fallback.
            max_retries_primary: Maximum retries when the primary API fails.
            max_retries_fallback: Maximum retries when the fallback API fails.
            retry_delay_seconds: Delay in seconds before each retry.
        """
        if not primary_api_key or not primary_base_url:
            raise ValueError("主 API 密钥和基础 URL 不能为空。")

        _primary_args = primary_client_args or {}
        self.primary_client = AsyncOpenAI(api_key=primary_api_key, base_url=primary_base_url, **_primary_args)
        self.primary_model_name = primary_model_name

        self.fallback_client: Optional[AsyncOpenAI] = None
        self.fallback_model_name: Optional[str] = None
        if fallback_api_key and fallback_base_url and fallback_model_name:
            _fallback_args = fallback_client_args or {}
            self.fallback_client = AsyncOpenAI(api_key=fallback_api_key, base_url=fallback_base_url, **_fallback_args)
            self.fallback_model_name = fallback_model_name
        else:
            print("⚠️ 警告: 未完全配置备用 API 客户端。如果主 API 失败,将无法进行回退。")

        self.content_filter_error_code = content_filter_error_code
        self.content_filter_error_field = content_filter_error_field
        self.max_retries_primary = max_retries_primary
        self.max_retries_fallback = max_retries_fallback
        self.retry_delay_seconds = retry_delay_seconds
        self._closed = False

    async def _attempt_api_call(
        self,
        client: AsyncOpenAI,
        model_name: str,
        messages: list[Mapping[str, Any]],
        max_retries: int,
        api_name: str,
        **kwargs: Any
    ) -> ChatCompletion:
        """
        Call the given OpenAI client, retrying on transient errors.
        """
        last_exception = None
        for attempt in range(max_retries + 1):
            try:
                completion = await client.chat.completions.create(
                    model=kwargs.pop('model', model_name),
                    messages=messages,
                    **kwargs
                )
                return completion
            except (APIConnectionError, APITimeoutError) as e:  # network errors that are usually retryable
                last_exception = e
                print(f"⚠️ {api_name} API 调用时发生可重试错误 ({type(e).__name__}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
                if attempt < max_retries:
                    await asyncio.sleep(self.retry_delay_seconds * (attempt + 1))  # grow the delay each attempt
                else:
                    print(f"❌ {api_name} API 在达到最大重试次数后仍然失败。")
            except APIStatusError as e:  # status-code errors returned by the API
                is_content_filter_error = False
                if e.status_code == 400:
                    try:
                        error_json = e.response.json()
                        error_details = error_json.get("error", {})
                        if (error_details.get("code") == self.content_filter_error_code and
                                self.content_filter_error_field in error_json):
                            is_content_filter_error = True
                    except Exception:
                        pass  # failed to parse the error response; treat as non-content-filter

                if is_content_filter_error and api_name == "主":  # re-raise primary content-filter errors so the caller can fall back
                    raise e

                last_exception = e
                print(f"⚠️ {api_name} API 调用时发生 APIStatusError ({e.status_code}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
                if attempt < max_retries:
                    await asyncio.sleep(self.retry_delay_seconds * (attempt + 1))
                else:
                    print(f"❌ {api_name} API 在达到最大重试次数后仍然失败 (APIStatusError)。")
            except APIError as e:  # other OpenAI errors that should not be retried
                last_exception = e
                print(f"❌ {api_name} API 调用时发生不可重试错误 ({type(e).__name__}): {e}")
                break  # do not retry these errors

        if last_exception:
            raise last_exception
        raise RuntimeError(f"{api_name} API 调用意外失败。")  # should be unreachable

    async def chat_completions_create(
        self,
        messages: list[Mapping[str, Any]],
        **kwargs: Any  # extra OpenAI parameters such as max_tokens, temperature, etc.
    ) -> ChatCompletion:
        """
        Create a chat completion on the primary API; fall back to the secondary API
        on a content-filter error or when the primary call ultimately fails.
        Retryable errors are retried on both the primary and the fallback API.

        Args:
            messages: Message list for the OpenAI API.
            **kwargs: Extra parameters passed through to the OpenAI API call.

        Returns:
            A ChatCompletion object.

        Raises:
            APIError: If both the primary and (if attempted) fallback APIs return API errors.
            RuntimeError: If the client has been closed.
        """
        if self._closed:
            raise RuntimeError("客户端已关闭。")

        try:
            completion = await self._attempt_api_call(
                client=self.primary_client,
                model_name=self.primary_model_name,
                messages=messages,
                max_retries=self.max_retries_primary,
                api_name="主",
                **kwargs.copy()
            )
            return completion
        except APIStatusError as e_primary:
            is_content_filter_error = False
            if e_primary.status_code == 400:
                try:
                    error_json = e_primary.response.json()
                    error_details = error_json.get("error", {})
                    if (error_details.get("code") == self.content_filter_error_code and
                            self.content_filter_error_field in error_json):
                        is_content_filter_error = True
                except Exception:
                    pass

            if is_content_filter_error and self.fallback_client and self.fallback_model_name:
                print(f"ℹ️ 主 API 内容过滤错误 ({e_primary.status_code})。尝试切换到备用 API ({self.fallback_client.base_url})...")
                try:
                    fallback_completion = await self._attempt_api_call(
                        client=self.fallback_client,
                        model_name=self.fallback_model_name,
                        messages=messages,
                        max_retries=self.max_retries_fallback,
                        api_name="备用",
                        **kwargs.copy()
                    )
                    print("✅ 备用 API 调用成功。")
                    return fallback_completion
                except APIError as e_fallback:
                    print(f"❌ 备用 API 调用最终失败: {type(e_fallback).__name__} - {e_fallback}")
                    raise e_fallback
            else:
                if not (self.fallback_client and self.fallback_model_name and is_content_filter_error):
                    # Not a content-filter error, or no fallback API configured: surface the primary error
                    print(f"ℹ️ 主 API 错误 ({type(e_primary).__name__}: {e_primary}), 且不满足备用条件或备用API未配置。")
                raise e_primary
        except APIError as e_primary_other:
            print(f"❌ 主 API 调用最终失败 (非内容过滤,错误类型: {type(e_primary_other).__name__}): {e_primary_other}")
            if self.fallback_client and self.fallback_model_name:
                print(f"ℹ️ 主 API 失败,尝试切换到备用 API ({self.fallback_client.base_url})...")
                try:
                    fallback_completion = await self._attempt_api_call(
                        client=self.fallback_client,
                        model_name=self.fallback_model_name,
                        messages=messages,
                        max_retries=self.max_retries_fallback,
                        api_name="备用",
                        **kwargs.copy()
                    )
                    print("✅ 备用 API 调用成功。")
                    return fallback_completion
                except APIError as e_fallback_after_primary_fail:
                    print(f"❌ 备用 API 在主 API 失败后也调用失败: {type(e_fallback_after_primary_fail).__name__} - {e_fallback_after_primary_fail}")
                    raise e_fallback_after_primary_fail
            else:
                raise e_primary_other

    async def close(self):
        """Close the primary client and, if present, the fallback client."""
        if not self._closed:
            await self.primary_client.close()
            if self.fallback_client:
                await self.fallback_client.close()
            self._closed = True

    async def __aenter__(self):
        if self._closed:
            raise RuntimeError("AsyncFallbackOpenAIClient 不能在关闭后重新进入。请创建一个新实例。")
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
25
utils/format_execution_result.py
Normal file
@@ -0,0 +1,25 @@
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def format_execution_result(result: Dict[str, Any]) -> str:
|
||||
"""格式化执行结果为用户可读的反馈"""
|
||||
feedback = []
|
||||
|
||||
if result['success']:
|
||||
feedback.append("✅ 代码执行成功")
|
||||
|
||||
if result['output']:
|
||||
feedback.append(f"📊 输出结果:\n{result['output']}")
|
||||
|
||||
if result.get('variables'):
|
||||
feedback.append("📋 新生成的变量:")
|
||||
for var_name, var_info in result['variables'].items():
|
||||
feedback.append(f" - {var_name}: {var_info}")
|
||||
else:
|
||||
feedback.append("❌ 代码执行失败")
|
||||
feedback.append(f"错误信息: {result['error']}")
|
||||
if result['output']:
|
||||
feedback.append(f"部分输出: {result['output']}")
|
||||
|
||||
return "\n".join(feedback)
|
||||
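The formatter is simple enough to sanity-check by hand. The sketch below restates the function so the snippet runs standalone (without the repo on `sys.path`); the sample result dict is hypothetical but follows the keys the code reads (`success`, `output`, `variables`, `error`):

```python
from typing import Any, Dict


def format_execution_result(result: Dict[str, Any]) -> str:
    """Restated from utils/format_execution_result.py so this snippet runs standalone."""
    feedback = []
    if result['success']:
        feedback.append("✅ 代码执行成功")
        if result['output']:
            feedback.append(f"📊 输出结果:\n{result['output']}")
        if result.get('variables'):
            feedback.append("📋 新生成的变量:")
            for var_name, var_info in result['variables'].items():
                feedback.append(f"  - {var_name}: {var_info}")
    else:
        feedback.append("❌ 代码执行失败")
        feedback.append(f"错误信息: {result['error']}")
        if result['output']:
            feedback.append(f"部分输出: {result['output']}")
    return "\n".join(feedback)


# Hypothetical sample following the keys the function reads.
ok = format_execution_result({
    "success": True,
    "output": "rows=100",
    "variables": {"df": "DataFrame(100x5)"},
})
print(ok)
```

Note that `variables` is the only key read defensively (`result.get`); `success`, `output`, and (on failure) `error` must be present.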
91
utils/llm_helper.py
Normal file
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
"""
LLM call helper module.
"""

import asyncio
import yaml
from config.llm_config import LLMConfig
from utils.fallback_openai_client import AsyncFallbackOpenAIClient


class LLMCallError(RuntimeError):
    """Raised when the configured LLM backend cannot complete a request."""


class LLMHelper:
    """Helper class for LLM calls, supporting both synchronous and asynchronous use."""

    def __init__(self, config: LLMConfig = None):
        self.config = config or LLMConfig()
        self.config.validate()
        self.client = AsyncFallbackOpenAIClient(
            primary_api_key=self.config.api_key,
            primary_base_url=self.config.base_url,
            primary_model_name=self.config.model
        )

    async def async_call(self, prompt: str, system_prompt: str = None, max_tokens: int = None, temperature: float = None) -> str:
        """Call the LLM asynchronously."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        kwargs = {}
        kwargs['max_tokens'] = max_tokens if max_tokens is not None else self.config.max_tokens
        kwargs['temperature'] = temperature if temperature is not None else self.config.temperature

        try:
            response = await self.client.chat_completions_create(
                messages=messages,
                **kwargs
            )
            return response.choices[0].message.content
        except Exception as e:
            raise LLMCallError(f"LLM调用失败: {e}") from e

    def call(self, prompt: str, system_prompt: str = None, max_tokens: int = None, temperature: float = None) -> str:
        """Call the LLM synchronously."""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        import nest_asyncio
        nest_asyncio.apply()

        return loop.run_until_complete(self.async_call(prompt, system_prompt, max_tokens, temperature))

    def parse_yaml_response(self, response: str) -> dict:
        """Parse a YAML-formatted response."""
        try:
            # Extract the content between ```yaml and ``` fences.
            if '```yaml' in response:
                start = response.find('```yaml') + 7
                end = response.find('```', start)
            elif '```' in response:
                start = response.find('```') + 3
                end = response.find('```', start)
            else:
                start, end = 0, len(response)
            if end == -1:
                # Unterminated fence: fall back to everything after the opener.
                end = len(response)
            yaml_content = response[start:end].strip()

            return yaml.safe_load(yaml_content)
        except Exception as e:
            print(f"YAML解析失败: {e}")
            print(f"原始响应: {response}")
            return {}

    async def close(self):
        """Close the underlying client."""
        await self.client.close()
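The fence handling in `parse_yaml_response` is the fiddly part: an unterminated fence makes `str.find` return -1, which would otherwise truncate the slice. A standalone sketch of just the extraction step, without PyYAML (`extract_fenced_block` is an illustrative name, not a repo function):

```python
def extract_fenced_block(response: str) -> str:
    """Return the payload between ```yaml (or bare ```) fences, else the raw text."""
    for fence in ("```yaml", "```"):
        if fence in response:
            start = response.find(fence) + len(fence)
            end = response.find("```", start)
            if end == -1:  # unterminated fence: take everything after the opener
                end = len(response)
            return response[start:end].strip()
    return response.strip()


demo = "Plan below:\n```yaml\nsteps:\n  - load\n  - clean\n```\nthanks"
extracted = extract_fenced_block(demo)
print(extracted)
```

Checking `"```yaml"` before the bare fence matters: the bare branch would otherwise leave a stray `yaml` prefix in the payload.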
4
webapp/__init__.py
Normal file
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
"""
Web application package for the data analysis platform.
"""
242
webapp/api.py
Normal file
@@ -0,0 +1,242 @@
# -*- coding: utf-8 -*-
"""
FastAPI application for the data analysis platform.
"""

import os
import re
import uuid
from typing import List, Optional

from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from webapp.session_manager import SessionManager
from webapp.storage import Storage, utcnow_iso
from webapp.task_runner import TaskRunner


BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RUNTIME_DIR = os.path.join(BASE_DIR, "runtime")
UPLOADS_DIR = os.path.join(RUNTIME_DIR, "uploads")
OUTPUTS_DIR = os.path.join(BASE_DIR, "outputs")
DB_PATH = os.path.join(RUNTIME_DIR, "analysis_platform.db")

os.makedirs(UPLOADS_DIR, exist_ok=True)
os.makedirs(OUTPUTS_DIR, exist_ok=True)

storage = Storage(DB_PATH)
session_manager = SessionManager(OUTPUTS_DIR)
task_runner = TaskRunner(
    storage=storage,
    uploads_dir=UPLOADS_DIR,
    outputs_dir=OUTPUTS_DIR,
    session_manager=session_manager,
    max_workers=2,
)

app = FastAPI(title="Data Analysis Platform API")
STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


class CreateSessionRequest(BaseModel):
    user_id: str = Field(..., min_length=1)
    title: str = Field(..., min_length=1)
    query: str = Field(..., min_length=1)
    file_ids: List[str]
    template_file_id: Optional[str] = None


class CreateTopicRequest(BaseModel):
    user_id: str = Field(..., min_length=1)
    query: str = Field(..., min_length=1)


def sanitize_filename(filename: str) -> str:
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", filename).strip("._")
    return cleaned or "upload.bin"


def ensure_session_access(session_id: str, user_id: str) -> dict:
    session = storage.get_session(session_id, user_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    return session


def ensure_task_access(task_id: str, user_id: str) -> dict:
    task = storage.get_task(task_id, user_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    return task


@app.get("/health")
def health():
    return {"status": "ok"}


@app.get("/", response_class=HTMLResponse)
def index():
    index_path = os.path.join(STATIC_DIR, "index.html")
    with open(index_path, "r", encoding="utf-8") as f:
        return HTMLResponse(f.read())


@app.post("/files/upload")
async def upload_files(
    user_id: str = Form(...),
    files: List[UploadFile] = File(...),
):
    saved = []
    # Sanitize the caller-supplied user_id so the upload directory cannot escape UPLOADS_DIR.
    user_dir = os.path.join(UPLOADS_DIR, sanitize_filename(user_id))
    os.makedirs(user_dir, exist_ok=True)

    for upload in files:
        safe_name = sanitize_filename(upload.filename or "upload.bin")
        stored_path = os.path.join(user_dir, f"{uuid.uuid4()}_{safe_name}")
        with open(stored_path, "wb") as f:
            while True:
                chunk = await upload.read(1024 * 1024)
                if not chunk:
                    break
                f.write(chunk)
        saved.append(storage.create_uploaded_file(user_id, upload.filename or safe_name, stored_path))

    return {"files": saved}


@app.get("/files")
def list_files(user_id: str = Query(...)):
    return {"files": storage.list_all_uploaded_files(user_id)}


@app.post("/sessions")
def create_session(request: CreateSessionRequest):
    if not storage.list_uploaded_files(request.file_ids, request.user_id):
        raise HTTPException(status_code=400, detail="No valid files found for session")

    session = storage.create_session(
        user_id=request.user_id,
        title=request.title,
        uploaded_file_ids=request.file_ids,
        template_file_id=request.template_file_id,
    )
    task = storage.create_task(
        session_id=session["id"],
        user_id=request.user_id,
        query=request.query,
        uploaded_file_ids=request.file_ids,
        template_file_id=request.template_file_id,
    )
    task_runner.submit(task["id"], request.user_id)
    return {"session": session, "task": task}


@app.get("/sessions")
def list_sessions(user_id: str = Query(...)):
    return {"sessions": storage.list_sessions(user_id)}


@app.get("/sessions/{session_id}")
def get_session(session_id: str, user_id: str = Query(...)):
    session = ensure_session_access(session_id, user_id)
    tasks = storage.list_session_tasks(session_id, user_id)
    return {"session": session, "tasks": tasks}


@app.post("/sessions/{session_id}/topics")
def create_followup_topic(session_id: str, request: CreateTopicRequest):
    session = ensure_session_access(session_id, request.user_id)
    if session["status"] == "closed":
        raise HTTPException(status_code=400, detail="Session is closed")

    task = storage.create_task(
        session_id=session_id,
        user_id=request.user_id,
        query=request.query,
        uploaded_file_ids=session["uploaded_file_ids"],
        template_file_id=session.get("template_file_id"),
    )
    task_runner.submit(task["id"], request.user_id)
    return {"session": session, "task": task}


@app.post("/sessions/{session_id}/close")
def close_session(session_id: str, user_id: str = Query(...)):
    session = ensure_session_access(session_id, user_id)
    storage.update_session(session_id, status="closed", closed_at=utcnow_iso())
    session_manager.close(session_id)
    return {"session": storage.get_session(session_id, user_id)}


@app.get("/tasks")
def list_tasks(user_id: str = Query(...)):
    return {"tasks": storage.list_tasks(user_id)}


@app.get("/tasks/{task_id}")
def get_task(task_id: str, user_id: str = Query(...)):
    task = ensure_task_access(task_id, user_id)
    return {"task": task}


@app.get("/tasks/{task_id}/report")
def get_task_report(task_id: str, user_id: str = Query(...)):
    task = ensure_task_access(task_id, user_id)
    report_path = task.get("report_file_path")
    if not report_path or not os.path.exists(report_path):
        raise HTTPException(status_code=404, detail="Report not available")
    return FileResponse(report_path, media_type="text/markdown", filename=os.path.basename(report_path))


@app.get("/tasks/{task_id}/report/content")
def get_task_report_content(task_id: str, user_id: str = Query(...)):
    task = ensure_task_access(task_id, user_id)
    report_path = task.get("report_file_path")
    if not report_path or not os.path.exists(report_path):
        raise HTTPException(status_code=404, detail="Report not available")
    with open(report_path, "r", encoding="utf-8") as f:
        return {"content": f.read(), "filename": os.path.basename(report_path)}


@app.get("/tasks/{task_id}/artifacts")
def list_task_artifacts(task_id: str, user_id: str = Query(...)):
    task = ensure_task_access(task_id, user_id)
    session_output_dir = task.get("session_output_dir")
    if not session_output_dir or not os.path.isdir(session_output_dir):
        return {"artifacts": []}

    artifacts = []
    for name in sorted(os.listdir(session_output_dir)):
        path = os.path.join(session_output_dir, name)
        if not os.path.isfile(path):
            continue
        artifacts.append(
            {
                "name": name,
                "size": os.path.getsize(path),
                "is_image": name.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".webp")),
                "url": f"/tasks/{task_id}/artifacts/{name}?user_id={user_id}",
            }
        )
    return {"artifacts": artifacts}


@app.get("/tasks/{task_id}/artifacts/{artifact_name}")
def get_artifact(task_id: str, artifact_name: str, user_id: str = Query(...)):
    task = ensure_task_access(task_id, user_id)
    session_output_dir = task.get("session_output_dir")
    if not session_output_dir:
        raise HTTPException(status_code=404, detail="Artifact directory not available")

    # Resolve the requested path first, then require it to stay inside the session directory.
    artifact_path = os.path.realpath(os.path.join(session_output_dir, artifact_name))
    session_root = os.path.realpath(session_output_dir)
    if artifact_path != session_root and not artifact_path.startswith(session_root + os.sep):
        raise HTTPException(status_code=400, detail="Invalid artifact path")
    if not os.path.exists(artifact_path):
        raise HTTPException(status_code=404, detail="Artifact not found")
    return FileResponse(artifact_path, filename=os.path.basename(artifact_path))
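webapp/api.py leans on two small safety measures: `sanitize_filename` for upload names and a realpath containment check before serving artifacts. Both can be sketched standalone (`is_contained` is an illustrative helper name; the endpoint inlines the same check):

```python
import os
import re
import tempfile


def sanitize_filename(filename: str) -> str:
    # Collapse anything outside [A-Za-z0-9._-] into "_" and trim edge dots/underscores.
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", filename).strip("._")
    return cleaned or "upload.bin"


def is_contained(root: str, candidate: str) -> bool:
    # Same idea as the artifact check: resolve symlinks and ".." first, then
    # require the result to be the root itself or to live under root + os.sep.
    root = os.path.realpath(root)
    path = os.path.realpath(os.path.join(root, candidate))
    return path == root or path.startswith(root + os.sep)


print(sanitize_filename("../../etc/passwd"))  # -> "etc_passwd"

tmp = tempfile.gettempdir()
inside = is_contained(tmp, "report.md")
outside = is_contained(tmp, "../" * 8 + "etc/passwd")
```

The `strip("._")` step is what defuses `..` sequences: dots survive the character class, but they can never remain at the edges of the cleaned name.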
66
webapp/session_manager.py
Normal file
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
"""
In-memory registry for long-lived analysis sessions.
"""

import os
import threading
from dataclasses import dataclass, field
from typing import Dict, Optional

from config.llm_config import LLMConfig
from data_analysis_agent import DataAnalysisAgent


@dataclass
class RuntimeSession:
    session_id: str
    user_id: str
    session_output_dir: str
    uploaded_files: list[str]
    template_path: Optional[str]
    agent: DataAnalysisAgent
    initialized: bool = False
    lock: threading.Lock = field(default_factory=threading.Lock)


class SessionManager:
    """Keeps session-scoped agents alive across follow-up topics."""

    def __init__(self, outputs_dir: str):
        self.outputs_dir = os.path.abspath(outputs_dir)
        self._sessions: Dict[str, RuntimeSession] = {}
        self._lock = threading.Lock()

    def get_or_create(
        self,
        session_id: str,
        user_id: str,
        session_output_dir: str,
        uploaded_files: list[str],
        template_path: Optional[str],
    ) -> RuntimeSession:
        with self._lock:
            runtime = self._sessions.get(session_id)
            if runtime is None:
                runtime = RuntimeSession(
                    session_id=session_id,
                    user_id=user_id,
                    session_output_dir=session_output_dir,
                    uploaded_files=uploaded_files,
                    template_path=template_path,
                    agent=DataAnalysisAgent(
                        llm_config=LLMConfig(),
                        output_dir=self.outputs_dir,
                        max_rounds=20,
                        force_max_rounds=False,
                    ),
                )
                self._sessions[session_id] = runtime
            return runtime

    def close(self, session_id: str) -> None:
        with self._lock:
            runtime = self._sessions.pop(session_id, None)
            if runtime:
                runtime.agent.close_session()
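SessionManager's core is a lock-guarded get-or-create over a dict, which is what lets follow-up topics reuse the same agent. A minimal generic sketch of that pattern (`Registry` and the lambda factory are illustrative stand-ins for SessionManager and DataAnalysisAgent):

```python
import threading
from typing import Any, Callable, Dict


class Registry:
    """Thread-safe get-or-create keyed store, mirroring SessionManager's shape."""

    def __init__(self, factory: Callable[[str], Any]):
        self._factory = factory
        self._items: Dict[str, Any] = {}
        self._lock = threading.Lock()

    def get_or_create(self, key: str) -> Any:
        with self._lock:
            item = self._items.get(key)
            if item is None:
                # Build at most one item per key; concurrent callers serialize on the lock.
                item = self._factory(key)
                self._items[key] = item
            return item

    def close(self, key: str):
        # Unlike SessionManager.close, this sketch returns the evicted item.
        with self._lock:
            return self._items.pop(key, None)


registry = Registry(factory=lambda key: {"session_id": key})
a = registry.get_or_create("s1")
b = registry.get_or_create("s1")
print(a is b)  # the same object is reused for follow-up topics
```

Holding the single manager lock across the factory call trades a little contention for the guarantee that two threads never build duplicate agents for one session.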
299
webapp/static/app.js
Normal file
@@ -0,0 +1,299 @@
const state = {
  userId: null,
  files: [],
  sessions: [],
  currentSessionId: null,
  currentTaskId: null,
  pollTimer: null,
};

function ensureUserId() {
  const key = "vibe_data_ana_user_id";
  let userId = localStorage.getItem(key);
  if (!userId) {
    userId = `guest_${crypto.randomUUID()}`;
    localStorage.setItem(key, userId);
  }
  state.userId = userId;
  document.getElementById("user-id").textContent = userId;
}

async function api(path, options = {}) {
  const response = await fetch(path, options);
  const text = await response.text();
  let data = {};
  try {
    data = text ? JSON.parse(text) : {};
  } catch {
    data = { raw: text };
  }
  if (!response.ok) {
    throw new Error(data.detail || data.error || response.statusText);
  }
  return data;
}

function setText(id, value) {
  document.getElementById(id).textContent = value || "";
}

function renderFiles() {
  const fileList = document.getElementById("file-list");
  const picker = document.getElementById("session-file-picker");
  fileList.innerHTML = "";
  picker.innerHTML = "";

  if (!state.files.length) {
    fileList.innerHTML = '<div class="empty">还没有上传文件。</div>';
    picker.innerHTML = '<div class="empty">先上传文件后才能创建会话。</div>';
    return;
  }

  state.files.forEach((file) => {
    const item = document.createElement("div");
    item.className = "file-item";
    item.innerHTML = `<strong>${file.original_name}</strong><div class="hint">${file.id}</div>`;
    fileList.appendChild(item);

    const label = document.createElement("label");
    label.className = "checkbox-item";
    label.innerHTML = `
      <input type="checkbox" value="${file.id}" />
      <span>${file.original_name}</span>
    `;
    picker.appendChild(label);
  });
}

function statusBadge(status) {
  return `<span class="status ${status}">${status}</span>`;
}

function renderSessions() {
  const container = document.getElementById("session-list");
  container.innerHTML = "";
  if (!state.sessions.length) {
    container.innerHTML = '<div class="empty">暂无会话。</div>';
    return;
  }

  state.sessions.forEach((session) => {
    const card = document.createElement("button");
    card.type = "button";
    card.className = `session-card ${session.id === state.currentSessionId ? "active" : ""}`;
    card.innerHTML = `
      <div><strong>${session.title}</strong></div>
      <div class="hint">${session.id}</div>
      <div>${statusBadge(session.status)}</div>
    `;
    card.onclick = () => loadSessionDetail(session.id);
    container.appendChild(card);
  });
}

function renderTasks(tasks) {
  const container = document.getElementById("task-list");
  container.innerHTML = "";
  if (!tasks.length) {
    container.innerHTML = '<div class="empty">当前会话还没有专题任务。</div>';
    return;
  }

  tasks.forEach((task) => {
    const card = document.createElement("button");
    card.type = "button";
    card.className = `task-card ${task.id === state.currentTaskId ? "active" : ""}`;
    card.innerHTML = `
      <div><strong>${task.query}</strong></div>
      <div class="hint">${task.created_at}</div>
      <div>${statusBadge(task.status)}</div>
    `;
    card.onclick = () => loadTaskReport(task.id);
    container.appendChild(card);
  });
}

async function refreshFiles() {
  const data = await api(`/files?user_id=${encodeURIComponent(state.userId)}`);
  state.files = data.files || [];
  renderFiles();
}

async function refreshSessions(selectSessionId = null) {
  const data = await api(`/sessions?user_id=${encodeURIComponent(state.userId)}`);
  state.sessions = data.sessions || [];
  renderSessions();
  if (selectSessionId) {
    await loadSessionDetail(selectSessionId);
  } else if (state.currentSessionId) {
    const exists = state.sessions.some((session) => session.id === state.currentSessionId);
    if (exists) {
      await loadSessionDetail(state.currentSessionId, false);
    }
  }
}

async function loadSessionDetail(sessionId, renderReport = true) {
  const data = await api(`/sessions/${sessionId}?user_id=${encodeURIComponent(state.userId)}`);
  state.currentSessionId = sessionId;
  document.getElementById("detail-title").textContent = data.session.title;
  document.getElementById("detail-meta").textContent = `${data.session.id} · ${data.session.status}`;
  renderSessions();
  renderTasks(data.tasks || []);

  const latestDoneTask = (data.tasks || []).slice().reverse().find((task) => task.status === "succeeded");
  if (renderReport && latestDoneTask) {
    await loadTaskReport(latestDoneTask.id);
  } else if (!latestDoneTask) {
    setText("report-title", "暂无已完成专题");
    setText("report-content", "当前会话还没有可展示的报告。");
    document.getElementById("artifact-gallery").innerHTML = "";
  }
}

async function loadTaskReport(taskId) {
  state.currentTaskId = taskId;
  renderSessions();
  const taskData = await api(`/tasks/${taskId}?user_id=${encodeURIComponent(state.userId)}`);
  setText("report-title", taskData.task.query);

  if (taskData.task.status !== "succeeded") {
    setText("report-content", `当前任务状态为 ${taskData.task.status}。\n错误信息:${taskData.task.error_message || "暂无"}`);
    document.getElementById("artifact-gallery").innerHTML = "";
    return;
  }

  const reportData = await api(`/tasks/${taskId}/report/content?user_id=${encodeURIComponent(state.userId)}`);
  setText("report-content", reportData.content || "");

  const artifactData = await api(`/tasks/${taskId}/artifacts?user_id=${encodeURIComponent(state.userId)}`);
  renderArtifacts(artifactData.artifacts || []);
}

function renderArtifacts(artifacts) {
  const gallery = document.getElementById("artifact-gallery");
  gallery.innerHTML = "";
  const images = artifacts.filter((item) => item.is_image);
  if (!images.length) {
    gallery.innerHTML = '<div class="empty">当前任务没有图片产物。</div>';
    return;
  }
  images.forEach((artifact) => {
    const card = document.createElement("div");
    card.className = "artifact-card";
    card.innerHTML = `
      <img src="${artifact.url}" alt="${artifact.name}" />
      <div><a href="${artifact.url}" target="_blank" rel="noreferrer">${artifact.name}</a></div>
    `;
    gallery.appendChild(card);
  });
}

async function handleUpload(event) {
  event.preventDefault();
  const input = document.getElementById("upload-input");
  if (!input.files.length) {
    setText("upload-status", "请选择至少一个文件。");
    return;
  }

  const formData = new FormData();
  formData.append("user_id", state.userId);
  Array.from(input.files).forEach((file) => formData.append("files", file));
  setText("upload-status", "上传中...");
  await api("/files/upload", { method: "POST", body: formData });
  setText("upload-status", "文件上传完成。");
  input.value = "";
  await refreshFiles();
}

async function handleCreateSession(event) {
  event.preventDefault();
  const checked = Array.from(document.querySelectorAll("#session-file-picker input:checked"));
  const fileIds = checked.map((item) => item.value);
  if (!fileIds.length) {
    setText("session-status", "请至少选择一个文件。");
    return;
  }

  const payload = {
    user_id: state.userId,
    title: document.getElementById("session-title").value.trim(),
    query: document.getElementById("session-query").value.trim(),
    file_ids: fileIds,
  };
  setText("session-status", "会话创建中...");
  const data = await api("/sessions", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(payload),
  });
  setText("session-status", "会话已创建,正在执行首个专题。");
  document.getElementById("session-form").reset();
  await refreshSessions(data.session.id);
}

async function handleFollowup() {
  if (!state.currentSessionId) {
    setText("detail-meta", "请先选择一个会话。");
    return;
  }
  const query = document.getElementById("followup-query").value.trim();
  if (!query) {
    return;
  }
  await api(`/sessions/${state.currentSessionId}/topics`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ user_id: state.userId, query }),
  });
  document.getElementById("followup-query").value = "";
  await refreshSessions(state.currentSessionId);
}

async function handleCloseSession() {
  if (!state.currentSessionId) {
    return;
  }
  await api(`/sessions/${state.currentSessionId}/close?user_id=${encodeURIComponent(state.userId)}`, {
    method: "POST",
  });
  await refreshSessions(state.currentSessionId);
}

function startPolling() {
  if (state.pollTimer) {
    clearInterval(state.pollTimer);
  }
  state.pollTimer = setInterval(() => {
    refreshSessions().catch((error) => console.error(error));
  }, 8000);
}

async function bootstrap() {
  ensureUserId();
  document.getElementById("upload-form").addEventListener("submit", (event) => {
    handleUpload(event).catch((error) => setText("upload-status", error.message));
  });
  document.getElementById("session-form").addEventListener("submit", (event) => {
    handleCreateSession(event).catch((error) => setText("session-status", error.message));
  });
  document.getElementById("submit-followup").onclick = () => {
    handleFollowup().catch((error) => setText("detail-meta", error.message));
  };
  document.getElementById("close-session").onclick = () => {
    handleCloseSession().catch((error) => setText("detail-meta", error.message));
  };
  document.getElementById("refresh-sessions").onclick = () => {
    refreshSessions().catch((error) => console.error(error));
  };

  await refreshFiles();
  await refreshSessions();
  startPolling();
}

bootstrap().catch((error) => {
  console.error(error);
  setText("detail-meta", error.message);
});
88
webapp/static/index.html
Normal file
@@ -0,0 +1,88 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
  <meta charset="UTF-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>Vibe Data Analysis</title>
  <link rel="stylesheet" href="/static/style.css" />
</head>
<body>
  <main class="app-shell">
    <section class="hero">
      <div>
        <p class="eyebrow">Vibe Data Analysis</p>
        <h1>在线数据分析会话</h1>
        <p class="subtle">
          上传文件,发起分析,会话结束后继续追问新的专题,不中断当前上下文。
        </p>
      </div>
      <div class="identity-card">
        <span>当前访客标识</span>
        <code id="user-id"></code>
      </div>
    </section>

    <section class="panel-grid">
      <section class="panel">
        <h2>1. 上传文件</h2>
        <form id="upload-form" class="stack-form">
          <input id="upload-input" type="file" multiple />
          <button type="submit">上传并登记</button>
        </form>
        <div id="upload-status" class="hint"></div>
        <div id="file-list" class="file-list"></div>
      </section>

      <section class="panel">
        <h2>2. 新建分析会话</h2>
        <form id="session-form" class="stack-form">
          <input id="session-title" type="text" placeholder="会话标题,例如:工单健康度" required />
          <textarea id="session-query" rows="5" placeholder="输入首个分析专题,例如:请先整体评估工单健康度,并指出最需要关注的问题。" required></textarea>
          <div id="session-file-picker" class="checkbox-list"></div>
          <button type="submit">创建会话并开始分析</button>
        </form>
        <div id="session-status" class="hint"></div>
      </section>
    </section>

    <section class="layout">
      <aside class="sidebar panel">
        <div class="sidebar-header">
          <h2>会话列表</h2>
          <button id="refresh-sessions" type="button">刷新</button>
        </div>
        <div id="session-list" class="session-list"></div>
      </aside>

      <section class="content panel">
        <div class="content-header">
          <div>
            <h2 id="detail-title">未选择会话</h2>
            <p id="detail-meta" class="hint">选择左侧会话查看分析结果与后续专题。</p>
          </div>
          <button id="close-session" type="button" class="ghost">结束当前会话</button>
        </div>

        <div class="followup-box">
          <textarea id="followup-query" rows="4" placeholder="如果还有新的专题想继续分析,在这里输入。"></textarea>
          <button id="submit-followup" type="button">继续分析该专题</button>
        </div>

        <div class="tasks-area">
          <div class="tasks-column">
            <h3>专题任务</h3>
            <div id="task-list" class="task-list"></div>
          </div>
          <div class="report-column">
            <h3>报告展示</h3>
            <div id="report-title" class="report-title">暂无报告</div>
            <pre id="report-content" class="report-content">选择一个已完成任务查看报告。</pre>
            <div id="artifact-gallery" class="artifact-gallery"></div>
          </div>
        </div>
      </section>
    </section>
  </main>
  <script src="/static/app.js"></script>
</body>
</html>
278
webapp/static/style.css
Normal file
@@ -0,0 +1,278 @@
:root {
  --bg: #f4efe7;
  --panel: #fffaf2;
  --line: #d8cdbd;
  --text: #1f1a17;
  --muted: #6e645a;
  --accent: #b04a2f;
  --accent-soft: #f2d5c4;
  --success: #2f7d62;
  --warning: #aa6a1f;
}

* {
  box-sizing: border-box;
}

body {
  margin: 0;
  font-family: "IBM Plex Sans", "Noto Sans SC", sans-serif;
  color: var(--text);
  background:
    radial-gradient(circle at top left, #fff7ec 0, transparent 28rem),
    linear-gradient(180deg, #efe5d7 0%, var(--bg) 100%);
}

.app-shell {
  max-width: 1440px;
  margin: 0 auto;
  padding: 24px;
}

.hero,
.panel,
.content,
.sidebar {
  border: 1px solid var(--line);
  background: rgba(255, 250, 242, 0.96);
  backdrop-filter: blur(10px);
  border-radius: 20px;
  box-shadow: 0 12px 40px rgba(93, 67, 39, 0.08);
}

.hero {
  display: flex;
  justify-content: space-between;
  gap: 24px;
  padding: 24px 28px;
  margin-bottom: 20px;
}

.eyebrow {
  margin: 0 0 10px;
  color: var(--accent);
  text-transform: uppercase;
  letter-spacing: 0.12em;
  font-size: 12px;
}

h1,
h2,
h3,
p {
  margin-top: 0;
}

.subtle,
.hint {
  color: var(--muted);
}

.identity-card {
  min-width: 220px;
  padding: 16px;
  border-radius: 16px;
  background: linear-gradient(135deg, var(--accent-soft), #f8e7dc);
  display: flex;
  flex-direction: column;
  gap: 8px;
}

.panel-grid,
.layout,
.tasks-area {
  display: grid;
  gap: 20px;
}

.panel-grid {
  grid-template-columns: repeat(2, minmax(0, 1fr));
  margin-bottom: 20px;
}

.layout {
  grid-template-columns: 340px minmax(0, 1fr);
}

.tasks-area {
  grid-template-columns: 320px minmax(0, 1fr);
}

.panel,
.content,
.sidebar {
  padding: 20px;
}

.stack-form {
  display: grid;
  gap: 12px;
}

input,
textarea,
button {
  font: inherit;
}

input,
textarea {
  width: 100%;
  padding: 12px 14px;
  border-radius: 12px;
  border: 1px solid var(--line);
  background: #fffdf9;
}

button {
  border: 0;
  border-radius: 999px;
  padding: 12px 18px;
  background: var(--accent);
  color: #fff;
  cursor: pointer;
  transition: transform 120ms ease, opacity 120ms ease;
}

button:hover {
  transform: translateY(-1px);
  opacity: 0.95;
}

button.ghost {
  background: transparent;
  color: var(--accent);
  border: 1px solid var(--accent-soft);
}

.file-list,
.checkbox-list,
.session-list,
.task-list,
.artifact-gallery {
  display: grid;
  gap: 10px;
}

.file-item,
.session-card,
.task-card,
.artifact-card {
  padding: 12px 14px;
  border-radius: 14px;
  border: 1px solid var(--line);
  background: #fffdf8;
}

.checkbox-item {
  display: flex;
  align-items: center;
  gap: 10px;
  padding: 10px 12px;
  border-radius: 12px;
  border: 1px solid var(--line);
  background: #fffdf8;
}

.sidebar-header,
.content-header {
  display: flex;
  justify-content: space-between;
  align-items: start;
  gap: 16px;
}

.followup-box {
  margin: 20px 0;
  display: grid;
  gap: 12px;
}

.session-card.active,
.task-card.active {
  border-color: var(--accent);
  background: #fff3eb;
}

.status {
  display: inline-flex;
  padding: 4px 10px;
  border-radius: 999px;
  font-size: 12px;
  font-weight: 600;
}

.status.queued {
  background: #f1ead7;
  color: #7b6114;
}

.status.running {
  background: #e6edf9;
  color: #1e5dab;
}

.status.succeeded,
.status.open {
  background: #dff1ea;
  color: var(--success);
}

.status.failed,
.status.closed {
  background: #f8dfda;
  color: #a03723;
}
||||
|
||||
.report-title {
|
||||
margin-bottom: 10px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.report-content {
|
||||
min-height: 360px;
|
||||
max-height: 720px;
|
||||
overflow: auto;
|
||||
padding: 16px;
|
||||
border-radius: 16px;
|
||||
background: #1d1a18;
|
||||
color: #f8f2ea;
|
||||
line-height: 1.6;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.artifact-gallery {
|
||||
margin-top: 16px;
|
||||
grid-template-columns: repeat(auto-fill, minmax(220px, 1fr));
|
||||
}
|
||||
|
||||
.artifact-card img {
|
||||
width: 100%;
|
||||
height: 180px;
|
||||
object-fit: cover;
|
||||
border-radius: 12px;
|
||||
background: #ede3d7;
|
||||
}
|
||||
|
||||
.artifact-card a {
|
||||
color: var(--accent);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.empty {
|
||||
color: var(--muted);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
@media (max-width: 1024px) {
|
||||
.panel-grid,
|
||||
.layout,
|
||||
.tasks-area,
|
||||
.hero {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
|
||||
.hero {
|
||||
flex-direction: column;
|
||||
}
|
||||
}
|
||||
311
webapp/storage.py
Normal file
@@ -0,0 +1,311 @@
# -*- coding: utf-8 -*-
"""
SQLite-backed storage for uploaded files and analysis tasks.
"""

import json
import os
import sqlite3
import threading
import uuid
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional


def utcnow_iso() -> str:
    return datetime.utcnow().replace(microsecond=0).isoformat() + "Z"


class Storage:
    """Simple SQLite storage with thread-safe write operations."""

    def __init__(self, db_path: str):
        self.db_path = os.path.abspath(db_path)
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        self._write_lock = threading.Lock()
        self.init_db()

    @contextmanager
    def _connect(self):
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
            conn.commit()
        finally:
            conn.close()

    def init_db(self) -> None:
        with self._connect() as conn:
            conn.executescript(
                """
                CREATE TABLE IF NOT EXISTS uploaded_files (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    original_name TEXT NOT NULL,
                    stored_path TEXT NOT NULL,
                    created_at TEXT NOT NULL
                );

                CREATE TABLE IF NOT EXISTS analysis_sessions (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    title TEXT NOT NULL,
                    status TEXT NOT NULL,
                    uploaded_file_ids TEXT NOT NULL,
                    template_file_id TEXT,
                    session_output_dir TEXT,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    closed_at TEXT
                );

                CREATE TABLE IF NOT EXISTS analysis_tasks (
                    id TEXT PRIMARY KEY,
                    session_id TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    query TEXT NOT NULL,
                    status TEXT NOT NULL,
                    uploaded_file_ids TEXT NOT NULL,
                    template_file_id TEXT,
                    session_output_dir TEXT,
                    report_file_path TEXT,
                    error_message TEXT,
                    created_at TEXT NOT NULL,
                    started_at TEXT,
                    finished_at TEXT
                );
                """
            )

    def create_uploaded_file(
        self, user_id: str, original_name: str, stored_path: str
    ) -> Dict[str, Any]:
        record = {
            "id": str(uuid.uuid4()),
            "user_id": user_id,
            "original_name": original_name,
            "stored_path": os.path.abspath(stored_path),
            "created_at": utcnow_iso(),
        }
        with self._write_lock, self._connect() as conn:
            conn.execute(
                """
                INSERT INTO uploaded_files (id, user_id, original_name, stored_path, created_at)
                VALUES (:id, :user_id, :original_name, :stored_path, :created_at)
                """,
                record,
            )
        return record

    def get_uploaded_file(self, file_id: str, user_id: str) -> Optional[Dict[str, Any]]:
        with self._connect() as conn:
            row = conn.execute(
                """
                SELECT * FROM uploaded_files WHERE id = ? AND user_id = ?
                """,
                (file_id, user_id),
            ).fetchone()
        return dict(row) if row else None

    def list_uploaded_files(
        self, file_ids: Iterable[str], user_id: str
    ) -> List[Dict[str, Any]]:
        file_ids = list(file_ids)
        if not file_ids:
            return []
        placeholders = ",".join("?" for _ in file_ids)
        params = [*file_ids, user_id]
        with self._connect() as conn:
            rows = conn.execute(
                f"""
                SELECT * FROM uploaded_files
                WHERE id IN ({placeholders}) AND user_id = ?
                ORDER BY created_at ASC
                """,
                params,
            ).fetchall()
        return [dict(row) for row in rows]

    def list_all_uploaded_files(self, user_id: str) -> List[Dict[str, Any]]:
        with self._connect() as conn:
            rows = conn.execute(
                """
                SELECT * FROM uploaded_files
                WHERE user_id = ?
                ORDER BY created_at DESC
                """,
                (user_id,),
            ).fetchall()
        return [dict(row) for row in rows]

    def create_task(
        self,
        session_id: str,
        user_id: str,
        query: str,
        uploaded_file_ids: List[str],
        template_file_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        record = {
            "id": str(uuid.uuid4()),
            "session_id": session_id,
            "user_id": user_id,
            "query": query,
            "status": "queued",
            "uploaded_file_ids": json.dumps(uploaded_file_ids, ensure_ascii=False),
            "template_file_id": template_file_id,
            "session_output_dir": None,
            "report_file_path": None,
            "error_message": None,
            "created_at": utcnow_iso(),
            "started_at": None,
            "finished_at": None,
        }
        with self._write_lock, self._connect() as conn:
            conn.execute(
                """
                INSERT INTO analysis_tasks (
                    id, session_id, user_id, query, status, uploaded_file_ids, template_file_id,
                    session_output_dir, report_file_path, error_message,
                    created_at, started_at, finished_at
                )
                VALUES (
                    :id, :session_id, :user_id, :query, :status, :uploaded_file_ids, :template_file_id,
                    :session_output_dir, :report_file_path, :error_message,
                    :created_at, :started_at, :finished_at
                )
                """,
                record,
            )
        return self.get_task(record["id"], user_id)

    def get_task(self, task_id: str, user_id: str) -> Optional[Dict[str, Any]]:
        with self._connect() as conn:
            row = conn.execute(
                """
                SELECT * FROM analysis_tasks WHERE id = ? AND user_id = ?
                """,
                (task_id, user_id),
            ).fetchone()
        return self._normalize_task(row) if row else None

    def list_tasks(self, user_id: str) -> List[Dict[str, Any]]:
        with self._connect() as conn:
            rows = conn.execute(
                """
                SELECT * FROM analysis_tasks
                WHERE user_id = ?
                ORDER BY created_at DESC
                """,
                (user_id,),
            ).fetchall()
        return [self._normalize_task(row) for row in rows]

    def list_session_tasks(self, session_id: str, user_id: str) -> List[Dict[str, Any]]:
        with self._connect() as conn:
            rows = conn.execute(
                """
                SELECT * FROM analysis_tasks
                WHERE session_id = ? AND user_id = ?
                ORDER BY created_at ASC
                """,
                (session_id, user_id),
            ).fetchall()
        return [self._normalize_task(row) for row in rows]

    def create_session(
        self,
        user_id: str,
        title: str,
        uploaded_file_ids: List[str],
        template_file_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        now = utcnow_iso()
        record = {
            "id": str(uuid.uuid4()),
            "user_id": user_id,
            "title": title,
            "status": "open",
            "uploaded_file_ids": json.dumps(uploaded_file_ids, ensure_ascii=False),
            "template_file_id": template_file_id,
            "session_output_dir": None,
            "created_at": now,
            "updated_at": now,
            "closed_at": None,
        }
        with self._write_lock, self._connect() as conn:
            conn.execute(
                """
                INSERT INTO analysis_sessions (
                    id, user_id, title, status, uploaded_file_ids, template_file_id,
                    session_output_dir, created_at, updated_at, closed_at
                )
                VALUES (
                    :id, :user_id, :title, :status, :uploaded_file_ids, :template_file_id,
                    :session_output_dir, :created_at, :updated_at, :closed_at
                )
                """,
                record,
            )
        return self.get_session(record["id"], user_id)

    def get_session(self, session_id: str, user_id: str) -> Optional[Dict[str, Any]]:
        with self._connect() as conn:
            row = conn.execute(
                """
                SELECT * FROM analysis_sessions WHERE id = ? AND user_id = ?
                """,
                (session_id, user_id),
            ).fetchone()
        return self._normalize_session(row) if row else None

    def list_sessions(self, user_id: str) -> List[Dict[str, Any]]:
        with self._connect() as conn:
            rows = conn.execute(
                """
                SELECT * FROM analysis_sessions
                WHERE user_id = ?
                ORDER BY updated_at DESC
                """,
                (user_id,),
            ).fetchall()
        return [self._normalize_session(row) for row in rows]

    def update_session(self, session_id: str, **fields: Any) -> None:
        if not fields:
            return
        fields["updated_at"] = utcnow_iso()
        assignments = ", ".join(f"{key} = :{key}" for key in fields.keys())
        payload = dict(fields)
        payload["id"] = session_id
        with self._write_lock, self._connect() as conn:
            conn.execute(
                f"UPDATE analysis_sessions SET {assignments} WHERE id = :id",
                payload,
            )

    def update_task(self, task_id: str, **fields: Any) -> None:
        if not fields:
            return
        assignments = ", ".join(f"{key} = :{key}" for key in fields.keys())
        payload = dict(fields)
        payload["id"] = task_id
        with self._write_lock, self._connect() as conn:
            conn.execute(
                f"UPDATE analysis_tasks SET {assignments} WHERE id = :id",
                payload,
            )

    @staticmethod
    def _normalize_task(row: sqlite3.Row) -> Dict[str, Any]:
        task = dict(row)
        task["uploaded_file_ids"] = json.loads(task["uploaded_file_ids"])
        return task

    @staticmethod
    def _normalize_session(row: sqlite3.Row) -> Dict[str, Any]:
        session = dict(row)
        session["uploaded_file_ids"] = json.loads(session["uploaded_file_ids"])
        return session
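Two small patterns in the storage layer above are worth calling out: a context manager that commits only when the `with` block exits cleanly (any exception skips the commit, and the connection is always closed), and UPDATE statements assembled from named placeholders so partial field updates stay fully parameterized. A standalone sketch of both against a throwaway database; the table and names here are illustrative, not part of the webapp API:

```python
import os
import sqlite3
import tempfile
from contextlib import contextmanager

@contextmanager
def connect(path):
    """Commit on clean exit, always close; mirrors Storage._connect."""
    conn = sqlite3.connect(path)
    conn.row_factory = sqlite3.Row  # rows support dict-style access
    try:
        yield conn
        conn.commit()  # reached only if the block raised nothing
    finally:
        conn.close()

db_path = os.path.join(tempfile.mkdtemp(), "demo.db")

with connect(db_path) as conn:
    conn.execute(
        "CREATE TABLE tasks (id TEXT PRIMARY KEY, status TEXT, error_message TEXT)"
    )
    conn.execute("INSERT INTO tasks VALUES ('t1', 'queued', 'boom')")

# Dynamic UPDATE built from keyword fields, like Storage.update_task:
fields = {"status": "running", "error_message": None}
assignments = ", ".join(f"{key} = :{key}" for key in fields)
payload = {**fields, "id": "t1"}

with connect(db_path) as conn:
    conn.execute(f"UPDATE tasks SET {assignments} WHERE id = :id", payload)
    row = conn.execute("SELECT * FROM tasks WHERE id = 't1'").fetchone()

print(dict(row))  # {'id': 't1', 'status': 'running', 'error_message': None}
```

Only the column *names* are interpolated into the SQL here, and they come from trusted keyword arguments; every *value* still travels through the named-placeholder binding, so the dynamic assembly does not open an injection hole.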
147
webapp/task_runner.py
Normal file
@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
"""
Background task runner for analysis jobs.
"""

import os
import shutil
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import redirect_stderr, redirect_stdout
from typing import Optional

from utils.create_session_dir import create_session_output_dir
from webapp.session_manager import SessionManager
from webapp.storage import Storage, utcnow_iso


class TaskRunner:
    """Runs analysis tasks in background worker threads."""

    def __init__(
        self,
        storage: Storage,
        uploads_dir: str,
        outputs_dir: str,
        session_manager: SessionManager,
        max_workers: int = 2,
    ):
        self.storage = storage
        self.uploads_dir = os.path.abspath(uploads_dir)
        self.outputs_dir = os.path.abspath(outputs_dir)
        self.session_manager = session_manager
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._lock = threading.Lock()
        self._submitted = set()

    def submit(self, task_id: str, user_id: str) -> None:
        with self._lock:
            if task_id in self._submitted:
                return
            self._submitted.add(task_id)
        self._executor.submit(self._run_task, task_id, user_id)

    def _run_task(self, task_id: str, user_id: str) -> None:
        try:
            task = self.storage.get_task(task_id, user_id)
            if not task:
                return
            session = self.storage.get_session(task["session_id"], user_id)
            if not session:
                return

            uploaded_files = self.storage.list_uploaded_files(
                task["uploaded_file_ids"], user_id
            )
            data_files = [item["stored_path"] for item in uploaded_files]
            template_path = self._resolve_template_path(task, user_id)
            session_output_dir = session.get("session_output_dir")
            if not session_output_dir:
                session_output_dir = create_session_output_dir(
                    self.outputs_dir, session["title"]
                )
                self.storage.update_session(
                    session["id"],
                    session_output_dir=session_output_dir,
                )
                session = self.storage.get_session(task["session_id"], user_id)

            runtime = self.session_manager.get_or_create(
                session_id=session["id"],
                user_id=user_id,
                session_output_dir=session_output_dir,
                uploaded_files=data_files,
                template_path=template_path,
            )

            self.storage.update_task(
                task_id,
                status="running",
                session_output_dir=session_output_dir,
                started_at=utcnow_iso(),
                error_message=None,
            )
            self.storage.update_session(session["id"], status="running")

            log_path = os.path.join(session_output_dir, "task.log")
            with runtime.lock:
                with open(log_path, "a", encoding="utf-8") as log_file:
                    log_file.write(
                        f"[{utcnow_iso()}] task started for session {session['id']}\n"
                    )
                    try:
                        with redirect_stdout(log_file), redirect_stderr(log_file):
                            result = runtime.agent.analyze(
                                user_input=task["query"],
                                files=data_files,
                                template_path=template_path,
                                session_output_dir=session_output_dir,
                                reset_context=not runtime.initialized,
                                keep_session_open=True,
                            )
                        runtime.initialized = True
                    except Exception as exc:
                        self.storage.update_task(
                            task_id,
                            status="failed",
                            error_message=str(exc),
                            finished_at=utcnow_iso(),
                            report_file_path=None,
                        )
                        self.storage.update_session(session["id"], status="open")
                        log_file.write(f"[{utcnow_iso()}] task failed: {exc}\n")
                        return

            report_file_path = self._persist_task_report(
                task_id, session_output_dir, result.get("report_file_path")
            )

            self.storage.update_task(
                task_id,
                status="succeeded",
                report_file_path=report_file_path,
                finished_at=utcnow_iso(),
                error_message=None,
            )
            self.storage.update_session(session["id"], status="open")
        finally:
            with self._lock:
                self._submitted.discard(task_id)

    def _resolve_template_path(self, task: dict, user_id: str) -> Optional[str]:
        template_file_id = task.get("template_file_id")
        if not template_file_id:
            return None
        file_record = self.storage.get_uploaded_file(template_file_id, user_id)
        return file_record["stored_path"] if file_record else None

    @staticmethod
    def _persist_task_report(
        task_id: str, session_output_dir: str, current_report_path: Optional[str]
    ) -> Optional[str]:
        if not current_report_path or not os.path.exists(current_report_path):
            return current_report_path
        task_report_path = os.path.join(session_output_dir, f"report_{task_id}.md")
        if os.path.abspath(current_report_path) != os.path.abspath(task_report_path):
            shutil.copyfile(current_report_path, task_report_path)
        return task_report_path
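`TaskRunner.submit` deduplicates work by tracking in-flight task ids in a lock-protected set: a second `submit` for the same id is dropped while the first run is queued or running, and each worker clears its id in a `finally` block so the task can be resubmitted later. A self-contained sketch of that handshake; the events exist only to make the race deterministic for the demo and are not part of the real runner:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

submitted = set()
lock = threading.Lock()
executor = ThreadPoolExecutor(max_workers=2)
started = threading.Event()
release = threading.Event()
runs = []

def run_task(task_id):
    try:
        runs.append(task_id)
        started.set()
        release.wait()  # hold the worker so a duplicate submit can race it
    finally:
        with lock:
            submitted.discard(task_id)  # allow future re-submission

def submit(task_id):
    with lock:
        if task_id in submitted:
            return  # already queued or running: drop the duplicate
        submitted.add(task_id)
    executor.submit(run_task, task_id)

submit("t1")
started.wait()   # first run is now in flight
submit("t1")     # deduplicated: id is still in the set
release.set()
executor.shutdown(wait=True)
print(runs)  # ['t1']
```

Note the set only guards against *concurrent* duplicates; once a run finishes, the same task id may be submitted and executed again, which matches the retry-after-failure behavior of the runner above.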