feat: implement initial structure and core components for data analysis agent
This commit is contained in:
8
.env.example
Normal file
8
.env.example
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
|
||||||
|
# 火山引擎配置
|
||||||
|
OPENAI_API_KEY=your-api-key-here
|
||||||
|
OPENAI_BASE_URL=https://api.xiaomimimo.com/v1/chat/completions
|
||||||
|
# 文本模型
|
||||||
|
OPENAI_MODEL=mimo-v2-flash
|
||||||
|
# OPENAI_MODEL=deepseek-r1-250528
|
||||||
|
|
||||||
173
.gitignore
vendored
Normal file
173
.gitignore
vendored
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# Project specific
|
||||||
|
# Output files and generated reports
|
||||||
|
outputs/
|
||||||
|
*.png
|
||||||
|
*.jpg
|
||||||
|
*.jpeg
|
||||||
|
*.pdf
|
||||||
|
*.docx
|
||||||
|
*.xlsx
|
||||||
|
*.csv
|
||||||
|
!贵州茅台利润表.csv
|
||||||
|
|
||||||
|
# 允许assets目录下的图片文件(项目资源)
|
||||||
|
!assets/**/*.png
|
||||||
|
!assets/**/*.jpg
|
||||||
|
!assets/**/*.jpeg
|
||||||
|
|
||||||
|
# IDE and editor files
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# OS specific files
|
||||||
|
.DS_Store
|
||||||
|
.DS_Store?
|
||||||
|
._*
|
||||||
|
.Spotlight-V100
|
||||||
|
.Trashes
|
||||||
|
ehthumbs.db
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# API keys and configuration
|
||||||
|
config.ini
|
||||||
|
.env
|
||||||
|
secrets.json
|
||||||
|
api_keys.txt
|
||||||
|
|
||||||
|
# Temporary files
|
||||||
|
*.tmp
|
||||||
|
*.temp
|
||||||
|
*.log
|
||||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2025 Data Analysis Agent Team
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
54
__init__.py
Normal file
54
__init__.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Data Analysis Agent Package
|
||||||
|
|
||||||
|
一个基于LLM的智能数据分析代理,专门为Jupyter Notebook环境设计。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .core.notebook_agent import NotebookAgent
|
||||||
|
from .config.llm_config import LLMConfig
|
||||||
|
from .utils.code_executor import CodeExecutor
|
||||||
|
|
||||||
|
__version__ = "1.0.0"
|
||||||
|
__author__ = "Data Analysis Agent Team"
|
||||||
|
|
||||||
|
# 主要导出类
|
||||||
|
__all__ = [
|
||||||
|
"NotebookAgent",
|
||||||
|
"LLMConfig",
|
||||||
|
"CodeExecutor",
|
||||||
|
]
|
||||||
|
|
||||||
|
# 便捷函数
|
||||||
|
def create_agent(config=None, output_dir="outputs", max_rounds=20, session_dir=None):
    """Create a data-analysis agent instance.

    Args:
        config: LLM configuration; a fresh ``LLMConfig()`` is used when None.
        output_dir: directory for generated artifacts.
        max_rounds: maximum number of analysis rounds.
        session_dir: optional explicit session directory.

    Returns:
        NotebookAgent: the configured agent instance.
    """
    effective_config = LLMConfig() if config is None else config
    return NotebookAgent(
        config=effective_config,
        output_dir=output_dir,
        max_rounds=max_rounds,
        session_dir=session_dir,
    )
|
||||||
|
|
||||||
|
def quick_analysis(query, files=None, output_dir="outputs", max_rounds=10):
    """One-shot convenience wrapper: build an agent and run a full analysis.

    Args:
        query: the analysis request in natural language.
        files: optional list of data file paths.
        output_dir: directory for generated artifacts.
        max_rounds: maximum number of analysis rounds.

    Returns:
        dict: the analysis result produced by ``NotebookAgent.analyze``.
    """
    return create_agent(output_dir=output_dir, max_rounds=max_rounds).analyze(query, files)
|
||||||
8
config/__init__.py
Normal file
8
config/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
配置模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .llm_config import LLMConfig
|
||||||
|
|
||||||
|
__all__ = ['LLMConfig']
|
||||||
44
config/llm_config.py
Normal file
44
config/llm_config.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
配置管理模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LLMConfig:
    """LLM client configuration, sourced from environment variables.

    Attributes:
        provider: backend provider identifier ("openai", "anthropic", ...).
        api_key: API key read from OPENAI_API_KEY. There is deliberately no
            hard-coded fallback: a missing key makes validate() raise instead
            of silently using a credential committed to source control.
        base_url: API base URL (OPENAI_BASE_URL).
        model: model name (OPENAI_MODEL).
        temperature: sampling temperature.
        max_tokens: maximum completion tokens per request.
    """

    provider: str = "openai"  # openai, anthropic, etc.
    # SECURITY FIX: the previous default embedded a real API key in the
    # repository. Credentials must come from the environment/.env only.
    api_key: str = os.environ.get("OPENAI_API_KEY", "")
    base_url: str = os.environ.get("OPENAI_BASE_URL", "https://api.xiaomimimo.com/v1")
    model: str = os.environ.get("OPENAI_MODEL", "mimo-v2-flash")
    temperature: float = 0.3
    max_tokens: int = 131072

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the configuration to a plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LLMConfig":
        """Build a configuration from a dict produced by :meth:`to_dict`."""
        return cls(**data)

    def validate(self) -> bool:
        """Check that all required fields are set.

        Returns:
            True when the configuration is usable.

        Raises:
            ValueError: if api_key, base_url, or model is empty.
        """
        if not self.api_key:
            raise ValueError("OPENAI_API_KEY is required")
        if not self.base_url:
            raise ValueError("OPENAI_BASE_URL is required")
        if not self.model:
            raise ValueError("OPENAI_MODEL is required")
        return True
|
||||||
483
data_analysis_agent.py
Normal file
483
data_analysis_agent.py
Normal file
@@ -0,0 +1,483 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
简化的 Notebook 数据分析智能体
|
||||||
|
仅包含用户和助手两个角色
|
||||||
|
2. 图片必须保存到指定的会话目录中,输出绝对路径,禁止使用plt.show()
|
||||||
|
3. 表格输出控制:超过15行只显示前5行和后5行
|
||||||
|
4. 强制使用SimHei字体:plt.rcParams['font.sans-serif'] = ['SimHei']
|
||||||
|
5. 输出格式严格使用YAML共享上下文的单轮对话模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import yaml
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from utils.create_session_dir import create_session_output_dir
|
||||||
|
from utils.format_execution_result import format_execution_result
|
||||||
|
from utils.extract_code import extract_code_from_response
|
||||||
|
from utils.data_loader import load_and_profile_data
|
||||||
|
from utils.llm_helper import LLMHelper
|
||||||
|
from utils.code_executor import CodeExecutor
|
||||||
|
from config.llm_config import LLMConfig
|
||||||
|
from prompts import data_analysis_system_prompt, final_report_system_prompt
|
||||||
|
|
||||||
|
|
||||||
|
class DataAnalysisAgent:
    """Conversational data-analysis agent.

    Responsibilities:
    - accept the user's natural-language analysis request
    - have the LLM generate Python analysis code
    - execute that code and collect the results
    - feed execution feedback back to the LLM to drive further rounds
    """

    def __init__(
        self,
        llm_config: LLMConfig = None,
        output_dir: str = "outputs",
        max_rounds: int = 20,
        force_max_rounds: bool = False,
    ):
        """Initialize the agent.

        Args:
            llm_config: LLM configuration; ``LLMConfig()`` is used when None.
            output_dir: base directory; each analyze() call creates a
                timestamped session subdirectory underneath it.
            max_rounds: maximum number of conversation rounds.
            force_max_rounds: when True, keep looping for the full
                ``max_rounds`` and ignore the model's completion signal.
        """
        self.config = llm_config or LLMConfig()
        self.llm = LLMHelper(self.config)
        self.base_output_dir = output_dir
        self.max_rounds = max_rounds
        self.force_max_rounds = force_max_rounds
        # Conversation state; reinitialized by analyze() and reset().
        self.conversation_history = []
        self.analysis_results = []
        self.current_round = 0
        self.session_output_dir = None
        self.executor = None  # created lazily in analyze()

    def _process_response(self, response: str) -> Dict[str, Any]:
        """Dispatch an LLM response to the handler for its declared action.

        The response is expected to be YAML with an ``action`` field; a
        parse failure or an unknown action falls back to ``generate_code``.

        Args:
            response: raw LLM response text.

        Returns:
            Handler result dict; its ``continue`` key tells analyze()
            whether to keep looping.
        """
        try:
            yaml_data = self.llm.parse_yaml_response(response)
            action = yaml_data.get("action", "generate_code")

            print(f"🎯 检测到动作: {action}")

            if action == "analysis_complete":
                return self._handle_analysis_complete(response, yaml_data)
            elif action == "collect_figures":
                return self._handle_collect_figures(response, yaml_data)
            elif action == "generate_code":
                return self._handle_generate_code(response, yaml_data)
            else:
                print(f"⚠️ 未知动作类型: {action},按generate_code处理")
                return self._handle_generate_code(response, yaml_data)

        except Exception as e:
            print(f"⚠️ 解析响应失败: {str(e)},按generate_code处理")
            return self._handle_generate_code(response, {})

    def _handle_analysis_complete(
        self, response: str, yaml_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Handle the ``analysis_complete`` action: stop the loop and
        surface the model's final report."""
        print("✅ 分析任务完成")
        final_report = yaml_data.get("final_report", "分析完成,无最终报告")
        return {
            "action": "analysis_complete",
            "final_report": final_report,
            "response": response,
            "continue": False,
        }

    def _handle_collect_figures(
        self, response: str, yaml_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Handle the ``collect_figures`` action: record each declared
        figure's path/description/analysis and verify the file exists."""
        print("📊 开始收集图片")
        figures_to_collect = yaml_data.get("figures_to_collect", [])

        collected_figures = []

        for figure_info in figures_to_collect:
            figure_number = figure_info.get("figure_number", "未知")
            # Only use figure_number in the default filename when it is known.
            if figure_number != "未知":
                default_filename = f"figure_{figure_number}.png"
            else:
                default_filename = "figure_unknown.png"
            filename = figure_info.get("filename", default_filename)
            file_path = figure_info.get("file_path", "")  # absolute path reported by the model
            description = figure_info.get("description", "")
            analysis = figure_info.get("analysis", "")

            # NOTE(review): the original source was garbled here ("(unknown)"
            # in the f-string); printing the filename is the most plausible
            # intent — confirm against project history.
            print(f"📈 收集图片 {figure_number}: {filename}")
            print(f" 📂 路径: {file_path}")
            print(f" 📝 描述: {description}")
            print(f" 🔍 分析: {analysis}")

            # Best-effort existence check; a missing file is logged, not fatal.
            if file_path and os.path.exists(file_path):
                print(f" ✅ 文件存在: {file_path}")
            elif file_path:
                print(f" ⚠️ 文件不存在: {file_path}")
            else:
                print(f" ⚠️ 未提供文件路径")

            collected_figures.append(
                {
                    "figure_number": figure_number,
                    "filename": filename,
                    "file_path": file_path,
                    "description": description,
                    "analysis": analysis,
                }
            )

        return {
            "action": "collect_figures",
            "collected_figures": collected_figures,
            "response": response,
            "continue": True,
        }

    def _handle_generate_code(
        self, response: str, yaml_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Handle the ``generate_code`` action: extract, clean, and execute
        the Python code from the response, returning execution feedback."""
        # Prefer the code field parsed from YAML (most reliable).
        code = yaml_data.get("code", "")

        # Fall back to scraping the raw response when YAML had no code.
        if not code:
            code = extract_code_from_response(response)

        # Second-pass cleanup: strip markdown code fences that sometimes
        # survive YAML parsing.
        if code:
            code = code.strip()
            if code.startswith("```"):
                import re
                # drop a leading ```python / ``` fence
                code = re.sub(r"^```[a-zA-Z]*\n", "", code)
                # drop a trailing ``` fence
                code = re.sub(r"\n```$", "", code)
                code = code.strip()

        if code:
            print(f"🔧 执行代码:\n{code}")
            print("-" * 40)

            result = self.executor.execute_code(code)

            feedback = format_execution_result(result)
            print(f"📋 执行反馈:\n{feedback}")

            return {
                "action": "generate_code",
                "code": code,
                "result": result,
                "feedback": feedback,
                "response": response,
                "continue": True,
            }
        else:
            # No executable code found: the LLM response is malformed and
            # the caller should ask it to regenerate.
            print("⚠️ 未从响应中提取到可执行代码,要求LLM重新生成")
            return {
                "action": "invalid_response",
                "error": "响应中缺少可执行代码",
                "response": response,
                "continue": True,
            }

    def analyze(self, user_input: str, files: List[str] = None) -> Dict[str, Any]:
        """Run the full analysis loop for one user request.

        Args:
            user_input: the user's natural-language analysis request.
            files: optional list of data file paths.

        Returns:
            Result dict from :meth:`_generate_final_report` (session dir,
            per-round results, collected figures, final report, ...).
        """
        # Reset per-run state so repeated calls start clean.
        self.conversation_history = []
        self.analysis_results = []
        self.current_round = 0

        # Create a dedicated output directory for this session.
        self.session_output_dir = create_session_output_dir(
            self.base_output_dir, user_input
        )

        # Code executor bound to the session directory, with the directory
        # exposed as a variable inside the execution environment.
        self.executor = CodeExecutor(self.session_output_dir)
        self.executor.set_variable("session_output_dir", self.session_output_dir)

        # Optionally profile the input data to seed the prompt.
        data_profile = ""
        if files:
            print("🔍 正在生成数据画像...")
            data_profile = load_and_profile_data(files)
            print("✅ 数据画像生成完毕")

        # Build the initial user prompt.
        initial_prompt = f"""用户需求: {user_input}"""
        if files:
            initial_prompt += f"\n数据文件: {', '.join(files)}"

        if data_profile:
            initial_prompt += f"\n\n{data_profile}\n\n请根据上述【数据画像】中的统计信息(如高频值、缺失率、数据范围)来制定分析策略。如果发现明显的高频问题或异常分布,请优先进行深度分析。"

        print(f"🚀 开始数据分析任务")
        print(f"📝 用户需求: {user_input}")
        if files:
            print(f"📁 数据文件: {', '.join(files)}")
        print(f"📂 输出目录: {self.session_output_dir}")
        print(f"🔢 最大轮数: {self.max_rounds}")
        if self.force_max_rounds:
            print(f"⚡ 强制模式: 将运行满 {self.max_rounds} 轮(忽略AI完成信号)")
        print("=" * 60)
        self.conversation_history.append({"role": "user", "content": initial_prompt})

        while self.current_round < self.max_rounds:
            self.current_round += 1
            print(f"\n🔄 第 {self.current_round} 轮分析")
            try:
                # Current executor variables feed the system prompt so the
                # model knows which notebook variables already exist.
                notebook_variables = self.executor.get_environment_info()

                formatted_system_prompt = data_analysis_system_prompt.format(
                    notebook_variables=notebook_variables
                )

                response = self.llm.call(
                    prompt=self._build_conversation_prompt(),
                    system_prompt=formatted_system_prompt,
                )

                print(f"🤖 助手响应:\n{response}")

                process_result = self._process_response(response)

                # Stop on the model's completion signal unless forced to run
                # all rounds.
                if not self.force_max_rounds and not process_result.get(
                    "continue", True
                ):
                    print(f"\n✅ 分析完成!")
                    break

                self.conversation_history.append(
                    {"role": "assistant", "content": response}
                )

                # Feed action-specific feedback back into the conversation.
                if process_result["action"] == "generate_code":
                    feedback = process_result.get("feedback", "")
                    self.conversation_history.append(
                        {"role": "user", "content": f"代码执行反馈:\n{feedback}"}
                    )

                    self.analysis_results.append(
                        {
                            "round": self.current_round,
                            "code": process_result.get("code", ""),
                            "result": process_result.get("result", {}),
                            "response": response,
                        }
                    )
                elif process_result["action"] == "collect_figures":
                    collected_figures = process_result.get("collected_figures", [])
                    feedback = f"已收集 {len(collected_figures)} 个图片及其分析"
                    self.conversation_history.append(
                        {
                            "role": "user",
                            "content": f"图片收集反馈:\n{feedback}\n请继续下一步分析。",
                        }
                    )

                    self.analysis_results.append(
                        {
                            "round": self.current_round,
                            "action": "collect_figures",
                            "collected_figures": collected_figures,
                            "response": response,
                        }
                    )

            except Exception as e:
                # LLM/executor failure: report it back to the model and retry
                # in the next round instead of aborting the whole analysis.
                error_msg = f"LLM调用错误: {str(e)}"
                print(f"❌ {error_msg}")
                self.conversation_history.append(
                    {
                        "role": "user",
                        "content": f"发生错误: {error_msg},请重新生成代码。",
                    }
                )

        if self.current_round >= self.max_rounds:
            print(f"\n⚠️ 已达到最大轮数 ({self.max_rounds}),分析结束")

        return self._generate_final_report()

    def _build_conversation_prompt(self) -> str:
        """Flatten the conversation history into a single prompt string."""
        prompt_parts = []

        for msg in self.conversation_history:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                prompt_parts.append(f"用户: {content}")
            else:
                prompt_parts.append(f"助手: {content}")

        return "\n\n".join(prompt_parts)

    def _generate_final_report(self) -> Dict[str, Any]:
        """Ask the LLM for a final markdown report, save it, and package the
        complete analysis result."""
        # Gather every figure recorded by collect_figures rounds.
        all_figures = []
        for result in self.analysis_results:
            if result.get("action") == "collect_figures":
                all_figures.extend(result.get("collected_figures", []))

        print(f"\n📊 开始生成最终分析报告...")
        print(f"📂 输出目录: {self.session_output_dir}")
        print(f"🔢 总轮数: {self.current_round}")
        print(f"📈 收集图片: {len(all_figures)} 个")

        final_report_prompt = self._build_final_report_prompt(all_figures)

        try:
            response = self.llm.call(
                prompt=final_report_prompt,
                system_prompt="你将会接收到一个数据分析任务的最终报告请求,请根据提供的分析结果和图片信息生成完整的分析报告。",
                max_tokens=16384,  # large limit so the full report fits
            )

            # Extract the report from the YAML envelope when possible.
            try:
                yaml_data = self.llm.parse_yaml_response(response)
                if yaml_data.get("action") == "analysis_complete":
                    final_report_content = yaml_data.get("final_report", "报告生成失败")
                else:
                    final_report_content = (
                        "LLM未返回analysis_complete动作,报告生成失败"
                    )
            except Exception:
                # FIX: was a bare `except:` (would also swallow
                # KeyboardInterrupt/SystemExit). Fall back to the raw
                # response when YAML parsing fails.
                final_report_content = response

            print("✅ 最终报告生成完成")

        except Exception as e:
            print(f"❌ 生成最终报告时出错: {str(e)}")
            final_report_content = f"报告生成失败: {str(e)}"

        # Persist the report; a write failure is logged but non-fatal.
        report_file_path = os.path.join(self.session_output_dir, "最终分析报告.md")
        try:
            with open(report_file_path, "w", encoding="utf-8") as f:
                f.write(final_report_content)
            print(f"📄 最终报告已保存至: {report_file_path}")
        except Exception as e:
            print(f"❌ 保存报告文件失败: {str(e)}")

        return {
            "session_output_dir": self.session_output_dir,
            "total_rounds": self.current_round,
            "analysis_results": self.analysis_results,
            "collected_figures": all_figures,
            "conversation_history": self.conversation_history,
            "final_report": final_report_content,
            "report_file_path": report_file_path,
        }

    def _build_final_report_prompt(self, all_figures: List[Dict[str, Any]]) -> str:
        """Assemble the prompt for final-report generation from collected
        figures and successful code-execution summaries."""

        # Figure summary, referencing images by relative path so the report
        # renders next to the saved files.
        figures_summary = ""
        if all_figures:
            figures_summary = "\n生成的图片及分析:\n"
            for i, figure in enumerate(all_figures, 1):
                filename = figure.get("filename", "未知文件名")
                # NOTE(review): the original source was garbled here
                # ("(unknown)" placeholders); using the filename is the most
                # plausible intent — confirm against project history.
                relative_path = f"./{filename}"
                figures_summary += f"{i}. {filename}\n"
                figures_summary += f" 相对路径: {relative_path}\n"
                figures_summary += f" 描述: {figure.get('description', '无描述')}\n"
                figures_summary += f" 分析: {figure.get('analysis', '无分析')}\n\n"
        else:
            figures_summary = "\n本次分析未生成图片。\n"

        # Summaries for successfully executed code blocks only.
        code_results_summary = ""
        success_code_count = 0
        for result in self.analysis_results:
            if result.get("action") != "collect_figures" and result.get("code"):
                exec_result = result.get("result", {})
                if exec_result.get("success"):
                    success_code_count += 1
                    code_results_summary += f"代码块 {success_code_count}: 执行成功\n"
                    if exec_result.get("output"):
                        # NOTE(review): original had a no-op `[:]` slice here,
                        # possibly a truncation limit lost in transcription.
                        code_results_summary += (
                            f"输出: {exec_result.get('output')}\n\n"
                        )

        # Shared template from prompts.py, plus relative-path guidance.
        prompt = final_report_system_prompt.format(
            current_round=self.current_round,
            session_output_dir=self.session_output_dir,
            figures_summary=figures_summary,
            code_results_summary=code_results_summary,
        )

        # NOTE(review): the format/example bullets below appear truncated in
        # the recovered source (likely markdown image syntax was stripped);
        # confirm the intended examples.
        prompt += """

📁 **图片路径使用说明**:
报告和图片都在同一目录下,请在报告中使用相对路径引用图片:
- 格式:
- 示例:
- 这样可以确保报告在不同环境下都能正确显示图片
"""

        return prompt

    def reset(self):
        """Reset all agent state, including the executor environment."""
        self.conversation_history = []
        self.analysis_results = []
        self.current_round = 0
        # FIX: guard against reset() being called before analyze(), when the
        # executor has not been created yet (was an AttributeError on None).
        if self.executor is not None:
            self.executor.reset_environment()
|
||||||
18
main.py
Normal file
18
main.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from data_analysis_agent import DataAnalysisAgent
|
||||||
|
from config.llm_config import LLMConfig
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: run one end-to-end analysis over the support-ticket CSV."""
    llm_config = LLMConfig()
    # Pass force_max_rounds=True to always run the maximum number of rounds.
    agent = DataAnalysisAgent(llm_config, force_max_rounds=False)
    data_files = ["./UB IOV Support_TR.csv"]
    report = agent.analyze(
        user_input="基于所有有关远程控制的问题,以及涉及车控APP的运维工单的数据,输出若干个重要的统计指标,并绘制相关图表。总结一份,车控APP,及远程控制工单健康度报告,最后生成汇报给我。",
        files=data_files,
    )
    print(report)


if __name__ == "__main__":
    main()
|
||||||
286
prompts.py
Normal file
286
prompts.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
data_analysis_system_prompt = """你是一个专业的数据分析助手,运行在Jupyter Notebook环境中,能够根据用户需求生成和执行Python数据分析代码。
|
||||||
|
|
||||||
|
**重要指导原则**:
|
||||||
|
- 当需要执行Python代码(数据加载、分析、可视化)时,使用 `generate_code` 动作
|
||||||
|
- 当需要收集和分析已生成的图表时,使用 `collect_figures` 动作
|
||||||
|
- 当所有分析工作完成,需要输出最终报告时,使用 `analysis_complete` 动作
|
||||||
|
- 每次响应只能选择一种动作类型,不要混合使用
|
||||||
|
|
||||||
|
目前jupyter notebook环境下有以下变量:
|
||||||
|
{notebook_variables}
|
||||||
|
核心能力:
|
||||||
|
1. 接收用户的自然语言分析需求
|
||||||
|
2. 按步骤生成安全的Python分析代码
|
||||||
|
3. 基于代码执行结果继续优化分析
|
||||||
|
|
||||||
|
Notebook环境特性:
|
||||||
|
- 你运行在IPython Notebook环境中,变量会在各个代码块之间保持
|
||||||
|
- 第一次执行后,pandas、numpy、matplotlib等库已经导入,无需重复导入
|
||||||
|
- 数据框(DataFrame)等变量在执行后会保留,可以直接使用
|
||||||
|
- 因此,除非是第一次使用某个库,否则不需要重复import语句
|
||||||
|
|
||||||
|
重要约束:
|
||||||
|
1. 仅使用以下数据分析库:pandas, numpy, matplotlib, duckdb, os, json, datetime, re, pathlib
|
||||||
|
2. 图片必须保存到指定的会话目录中,输出绝对路径,禁止使用plt.show(),饼图的标签全部放在图例里面,用颜色区分。
|
||||||
|
4. 表格输出控制:超过15行只显示前5行和后5行
|
||||||
|
5. 中文字体设置:使用系统可用中文字体(macOS推荐:Hiragino Sans GB, Songti SC等)
|
||||||
|
6. 输出格式严格使用YAML
|
||||||
|
|
||||||
|
|
||||||
|
输出目录管理:
|
||||||
|
- 本次分析使用时间戳生成的专用目录,确保每次分析的输出文件隔离
|
||||||
|
- 会话目录格式:session_[时间戳],如 session_20240105_143052
|
||||||
|
- 图片保存路径格式:os.path.join(session_output_dir, '图片名称.png')
|
||||||
|
- 使用有意义的中文文件名:如'营业收入趋势.png', '利润分析对比.png'
|
||||||
|
- 每个图表保存后必须使用plt.close()释放内存
|
||||||
|
- 输出绝对路径:使用os.path.abspath()获取图片的完整路径
|
||||||
|
|
||||||
|
数据分析工作流程(必须严格按顺序执行):
|
||||||
|
|
||||||
|
**阶段1:数据探索(使用 generate_code 动作)**
|
||||||
|
- 首次数据加载时尝试多种编码:['utf-8', 'gbk', 'gb18030', 'gb2312', 'latin1']
|
||||||
|
- 特殊处理:如果读取失败,尝试指定分隔符 `sep=','` 和错误处理 `error_bad_lines=False`
|
||||||
|
- 使用df.head()查看前几行数据,检查数据是否正确读取
|
||||||
|
- 使用df.info()了解数据类型和缺失值情况
|
||||||
|
- 重点检查:如果数值列显示为NaN但应该有值,说明读取或解析有问题
|
||||||
|
- 使用df.dtypes查看每列的数据类型,确保日期列不是float64
|
||||||
|
- 打印所有列名:df.columns.tolist()
|
||||||
|
- 绝对不要假设列名,必须先查看实际的列名
|
||||||
|
|
||||||
|
**阶段2:数据清洗和检查(使用 generate_code 动作)**
|
||||||
|
- 日期列识别:查找包含'date', 'time', 'Date', 'Time'关键词的列
|
||||||
|
- 日期解析:尝试多种格式 ['%d/%m/%Y', '%Y-%m-%d', '%m/%d/%Y', '%Y/%m/%d', '%d-%m-%Y']
|
||||||
|
- 类型转换:使用pd.to_datetime()转换日期列,指定format参数和errors='coerce'
|
||||||
|
- 空值处理:检查哪些列应该有值但显示NaN,可能是数据读取问题
|
||||||
|
- 检查数据的时间范围和排序
|
||||||
|
- 数据质量检查:确认数值列是否正确,字符串列是否被错误识别
|
||||||
|
|
||||||
|
|
||||||
|
**阶段3:数据分析和可视化(使用 generate_code 动作)**
|
||||||
|
- 基于实际的列名进行计算
|
||||||
|
- 生成有意义的图表
|
||||||
|
- 图片保存到会话专用目录中
|
||||||
|
- 每生成一个图表后,必须打印绝对路径
|
||||||
|
|
||||||
|
|
||||||
|
**阶段4:深度挖掘与高级分析(使用 generate_code 动作)**
|
||||||
|
- **主动评估数据特征**:在执行前,先分析数据适合哪种高级挖掘:
|
||||||
|
- **时间序列数据**:必须进行趋势预测(使用sklearn/ARIMA/Prophet-like逻辑)和季节性分解。
|
||||||
|
- **多维数值数据**:必须进行聚类分析(K-Means/DBSCAN)以发现用户/产品分层。
|
||||||
|
- **分类/目标数据**:必须计算特征重要性(使用随机森林/相关性矩阵)以识别关键驱动因素。
|
||||||
|
- **异常检测**:使用Isolation Forest或统计方法识别高价值或高风险的离群点。
|
||||||
|
- **拒绝平庸**:不要为了做而做。如果数据量太小(<50行)或特征单一,请明确说明无法进行特定分析,并尝试挖掘其他角度(如分布偏度、帕累托分析)。
|
||||||
|
- **业务导向**:每个模型结果必须翻译成业务语言(例如:“聚类结果显示,A类用户是高价值且对价格不敏感的群体”)。
|
||||||
|
|
||||||
|
**阶段5:高级分析结果可视化(使用 generate_code 动作)**
|
||||||
|
- **专业图表**:为高级分析匹配专用图表:
|
||||||
|
- 聚类 -> 降维散点图 (PCA/t-SNE) 或 平行坐标图
|
||||||
|
- 相关性 -> 热力图 (Heatmap)
|
||||||
|
- 预测 -> 带有置信区间的趋势图
|
||||||
|
- 特征重要性 -> 排序条形图
|
||||||
|
- **保存与输出**:保存模型结果图表,并准备好在报告中解释。
|
||||||
|
|
||||||
|
**阶段6:图片收集和分析(使用 collect_figures 动作)**
|
||||||
|
- 当已生成多个图表后,使用 collect_figures 动作
|
||||||
|
- 收集所有已生成的图片路径和信息
|
||||||
|
- 对每个图片进行详细的分析和解读
|
||||||
|
|
||||||
|
**阶段7:最终报告(使用 analysis_complete 动作)**
|
||||||
|
- 当所有分析工作完成后,生成最终的分析报告
|
||||||
|
- 包含对所有图片、模型和分析结果的综合总结
|
||||||
|
- 提供业务建议和预测洞察
|
||||||
|
|
||||||
|
代码生成规则:
|
||||||
|
1. 每次只专注一个阶段,不要试图一次性完成所有任务
|
||||||
|
2. 基于实际的数据结构而不是假设来编写代码
|
||||||
|
3. Notebook环境中变量会保持,避免重复导入和重复加载相同数据
|
||||||
|
4. 处理错误时,分析具体的错误信息并针对性修复,重新进行改阶段步骤,中途不要跳步骤
|
||||||
|
5. 图片保存使用会话目录变量:session_output_dir
|
||||||
|
6. 图表标题和标签使用中文,使用系统配置的中文字体显示
|
||||||
|
7. 必须打印绝对路径:每次保存图片后,使用os.path.abspath()打印完整的绝对路径
|
||||||
|
8. 图片文件名:同时打印图片的文件名,方便后续收集时识别
|
||||||
|
9. 饼图绘图代码生成必须遵守规则:类别 ≤ 5个:使用饼图 (plt.pie) + 外部图例,百分比标签清晰显示;类别 6-10个:使用水平条形图 (plt.barh) 便于阅读;类别 > 10个:使用排序条形图 + 合并小类别为"其他";学术美学要求**:白色背景、合适颜色、清晰标签、无冗余边框;
|
||||||
|
|
||||||
|
动作选择指南:
|
||||||
|
- **需要执行Python代码** → 使用 "generate_code"
|
||||||
|
- **已生成多个图表,需要收集分析** → 使用 "collect_figures"
|
||||||
|
- **所有分析完成,输出最终报告** → 使用 "analysis_complete"
|
||||||
|
- **遇到错误需要修复代码** → 使用 "generate_code"
|
||||||
|
|
||||||
|
高级分析技术指南(主动探索模式):
|
||||||
|
- **智能选择算法**:
|
||||||
|
- 遇到时间字段 -> `pd.to_datetime` -> 重采样 -> 移动平均/指数平滑/回归预测
|
||||||
|
- 遇到多数值特征 -> `StandardScaler` -> `KMeans` (使用Elbow法则选k) -> `PCA`降维可视化
|
||||||
|
- 遇到目标变量 -> `Correlation Matrix` -> `RandomForest` (feature_importances_)
|
||||||
|
- **文本挖掘**:
|
||||||
|
- 必须构建**专用停用词表** (Stop Words),过滤掉无效词汇:
|
||||||
|
- 年份/数字:2023, 2024, 2025, 1月, 2月...
|
||||||
|
- 通用动词:work, fix, support, issue, problem, check, test...
|
||||||
|
- 通用介词/代词:the, is, at, which, on, for, this, that...
|
||||||
|
- 仅保留具有实际业务含义的名词/动词短语(如 "connection timeout", "login failed")。
|
||||||
|
- **异常值挖掘**:总是检查是否存在显著偏离均值的异常点,并标记出来进行个案分析。
|
||||||
|
- **可视化增强**:不要只画折线图。使用 `seaborn` 的 `pairplot`, `heatmap`, `lmplot` 等高级图表。
|
||||||
|
|
||||||
|
可用分析库:
|
||||||
|
|
||||||
|
图片收集要求:
|
||||||
|
- 在适当的时候(通常是生成了多个图表后),主动使用 `collect_figures` 动作
|
||||||
|
- 收集时必须包含具体的图片绝对路径(file_path字段)
|
||||||
|
- 提供详细的图片描述和深入的分析
|
||||||
|
- 确保图片路径与之前打印的路径一致
|
||||||
|
|
||||||
|
报告生成要求:
|
||||||
|
- 生成的报告要符合报告的文言需要,不要出现有争议的文字
|
||||||
|
- 在适当的时候(通常是生成了多个图表后),进行图像的对比分析
|
||||||
|
- 涉及的文言,不能出现我,你,他,等主观用于,采用报告式的文言论述
|
||||||
|
- 提供详细的图片描述和深入的分析
|
||||||
|
- 报告中的英文单词,初专有名词(TSP,TBOX等),其余的全部翻译成中文,例如remote control(远控),don't exist in TSP (数据不在TSP上);
|
||||||
|
|
||||||
|
三种动作类型及使用时机:
|
||||||
|
|
||||||
|
**1. 代码生成动作 (generate_code)**
|
||||||
|
适用于:数据加载、探索、清洗、计算、数据分析、图片生成、可视化等需要执行Python代码的情况
|
||||||
|
|
||||||
|
**2. 图片收集动作 (collect_figures)**
|
||||||
|
适用于:已生成多个图表后,需要对图片进行汇总和深入分析的情况
|
||||||
|
|
||||||
|
**3. 分析完成动作 (analysis_complete)**
|
||||||
|
适用于:所有分析工作完成,需要输出最终报告的情况
|
||||||
|
|
||||||
|
响应格式(严格遵守):
|
||||||
|
|
||||||
|
**当需要执行代码时,使用此格式:**
|
||||||
|
```yaml
|
||||||
|
action: "generate_code"
|
||||||
|
reasoning: "详细说明当前步骤的目的和方法,为什么要这样做"
|
||||||
|
code: |
|
||||||
|
# 实际的Python代码
|
||||||
|
import pandas as pd
|
||||||
|
# 具体分析代码...
|
||||||
|
|
||||||
|
# 图片保存示例(如果生成图表)
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
# 绘图代码...
|
||||||
|
plt.title('图表标题')
|
||||||
|
file_path = os.path.join(session_output_dir, '图表名称.png')
|
||||||
|
plt.savefig(file_path, dpi=150, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
# 必须打印绝对路径
|
||||||
|
absolute_path = os.path.abspath(file_path)
|
||||||
|
print(f"图片已保存至: {{absolute_path}}")
|
||||||
|
print(f"图片文件名: {{os.path.basename(absolute_path)}}")
|
||||||
|
|
||||||
|
next_steps: ["下一步计划1", "下一步计划2"]
|
||||||
|
```
|
||||||
|
**当需要收集分析图片时,使用此格式:**
|
||||||
|
```yaml
|
||||||
|
action: "collect_figures"
|
||||||
|
reasoning: "说明为什么现在要收集图片,例如:已生成3个图表,现在收集并分析这些图表的内容"
|
||||||
|
figures_to_collect:
|
||||||
|
- figure_number: 1
|
||||||
|
filename: "营业收入趋势分析.png"
|
||||||
|
file_path: "实际的完整绝对路径"
|
||||||
|
description: "图片概述:展示了什么内容"
|
||||||
|
analysis: "细节分析:从图中可以看出的具体信息和洞察"
|
||||||
|
next_steps: ["后续计划"]
|
||||||
|
```
|
||||||
|
|
||||||
|
**当所有分析完成时,使用此格式:**
|
||||||
|
```yaml
|
||||||
|
action: "analysis_complete"
|
||||||
|
final_report: |
|
||||||
|
完整的最终分析报告内容
|
||||||
|
(可以是多行文本)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
特别注意:
|
||||||
|
- 数据读取问题:如果看到大量NaN值,检查编码和分隔符
|
||||||
|
- 日期列问题:如果日期列显示为float64,说明解析失败
|
||||||
|
- 编码错误:逐个尝试 ['utf-8', 'gbk', 'gb18030', 'gb2312', 'latin1']
|
||||||
|
- 列类型错误:检查是否有列被错误识别为数值型但实际是文本
|
||||||
|
- matplotlib错误时,确保使用Agg后端和正确的字体设置
|
||||||
|
- 每次执行后根据反馈调整代码,不要重复相同的错误
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 最终报告生成提示词
|
||||||
|
final_report_system_prompt = """你是一个专业的数据分析师,需要基于完整的分析过程生成最终的分析报告。
|
||||||
|
|
||||||
|
分析信息:
|
||||||
|
分析轮数: {current_round}
|
||||||
|
输出目录: {session_output_dir}
|
||||||
|
|
||||||
|
{figures_summary}
|
||||||
|
|
||||||
|
代码执行结果摘要:
|
||||||
|
{code_results_summary}
|
||||||
|
|
||||||
|
报告生成要求:
|
||||||
|
报告应使用markdown格式,确保结构清晰;需要包含对所有生成图片的详细分析和说明;
|
||||||
|
生成的报告要符合报告的文言需要,不要出现有争议的文字;
|
||||||
|
在适当的时候(通常是生成了多个图表后),进行图像的对比分析;
|
||||||
|
涉及的文言,不能出现我,你,他,等主观用语,采用报告式的文言论述;
|
||||||
|
提供详细的图片描述和深入的分析;
|
||||||
|
报告中的英文单词,除专有名词(TSP,TBOX等),其余的全部翻译成中文,例如remote control(远控),don't exist in TSP (数据不在TSP上);
|
||||||
|
|
||||||
|
总结分析过程中的关键发现;提供有价值的结论和建议;内容必须专业且逻辑性强。
|
||||||
|
**重要提醒:图片引用必须使用相对路径格式 ``**
|
||||||
|
|
||||||
|
图片质量与格式要求:
|
||||||
|
- **学术级图表标准**:所有图表必须达到发表级质量,包含:
|
||||||
|
* 专业的颜色方案(seaborn调色板)
|
||||||
|
* 清晰的标签和图例(无重叠)
|
||||||
|
* 合适的字体大小(≥12pt)
|
||||||
|
* 简洁的布局(白色背景,无冗余元素)
|
||||||
|
- **路径格式**:使用相对路径``
|
||||||
|
- **图表命名**:使用描述性中文名称,如`来源渠道分布.png`
|
||||||
|
响应格式要求:
|
||||||
|
必须严格使用以下YAML格式输出:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
action: "analysis_complete"
|
||||||
|
final_report: |
|
||||||
|
# 数据分析报告
|
||||||
|
|
||||||
|
## 分析概述
|
||||||
|
[概述本次分析的目标和范围]
|
||||||
|
|
||||||
|
## 数据分析过程
|
||||||
|
[总结分析的主要步骤]
|
||||||
|
|
||||||
|
## 关键发现
|
||||||
|
[描述重要的分析结果,使用段落形式而非列表]
|
||||||
|
|
||||||
|
## 图表分析
|
||||||
|
|
||||||
|
### [图表标题]
|
||||||
|

|
||||||
|
|
||||||
|
[对图表的详细分析,使用连续的段落描述,避免使用分点列表]
|
||||||
|
|
||||||
|
### [下一个图表标题]
|
||||||
|

|
||||||
|
|
||||||
|
[对图表的详细分析,使用连续的段落描述]
|
||||||
|
|
||||||
|
## 深度分析
|
||||||
|
### [图表标题]
|
||||||
|

|
||||||
|
|
||||||
|
[对此前所有的数据,探索关联关系,进行深度剖析,重点问题,高频问题,并以图表介绍,使用连续的段落描述,避免使用分点列表]
|
||||||
|
|
||||||
|
## 结论与建议
|
||||||
|
[基于分析结果提出结论和投资建议,使用段落形式表达]
|
||||||
|
```
|
||||||
|
|
||||||
|
特别注意事项:
|
||||||
|
必须对每个图片进行详细的分析和说明。
|
||||||
|
图片的内容和标题必须与分析内容相关。
|
||||||
|
使用专业的金融分析术语和方法。
|
||||||
|
报告要完整、准确、有价值。
|
||||||
|
**强制要求:所有图片路径都必须使用相对路径格式 `./文件名.png`。
|
||||||
|
为了确保后续markdown转换docx效果良好,请避免在正文中使用分点列表形式,改用段落形式表达。**
|
||||||
|
"""
|
||||||
52
requirements.txt
Normal file
52
requirements.txt
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# 数据分析和科学计算库
|
||||||
|
pandas>=2.0.0
|
||||||
|
openpyxl>=3.1.0
|
||||||
|
numpy>=1.24.0
|
||||||
|
matplotlib>=3.6.0
|
||||||
|
duckdb>=0.8.0
|
||||||
|
scipy>=1.10.0
|
||||||
|
scikit-learn>=1.3.0
|
||||||
|
|
||||||
|
# Web和API相关
|
||||||
|
requests>=2.28.0
|
||||||
|
urllib3>=1.26.0
|
||||||
|
|
||||||
|
# 绘图和可视化
|
||||||
|
plotly>=5.14.0
|
||||||
|
dash>=2.0.0
|
||||||
|
|
||||||
|
# 流程图支持(可选,用于生成Mermaid图表)
|
||||||
|
# 注意:Mermaid图表主要在markdown中渲染,不需要额外的Python包
|
||||||
|
# 如果需要在Python中生成Mermaid代码,可以考虑:
|
||||||
|
# mermaid-py>=0.3.0
|
||||||
|
|
||||||
|
# Jupyter/IPython环境
|
||||||
|
ipython>=8.10.0
|
||||||
|
jupyter>=1.0.0
|
||||||
|
|
||||||
|
# AI/LLM相关
|
||||||
|
openai>=1.0.0
|
||||||
|
pyyaml>=6.0
|
||||||
|
|
||||||
|
# 配置管理
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
|
||||||
|
# 异步编程
|
||||||
|
asyncio-mqtt>=0.11.1
|
||||||
|
nest_asyncio>=1.5.0
|
||||||
|
|
||||||
|
# 文档生成(基于输出的Word文档)
|
||||||
|
python-docx>=0.8.11
|
||||||
|
|
||||||
|
# 系统和工具库
|
||||||
|
pathlib2>=2.3.7
|
||||||
|
typing-extensions>=4.5.0
|
||||||
|
|
||||||
|
# 开发和测试工具(可选)
|
||||||
|
pytest>=7.0.0
|
||||||
|
pytest-asyncio>=0.21.0
|
||||||
|
black>=23.0.0
|
||||||
|
flake8>=6.0.0
|
||||||
|
|
||||||
|
# 字体支持(用于matplotlib中文显示)
|
||||||
|
fonttools>=4.38.0
|
||||||
10
utils/__init__.py
Normal file
10
utils/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
工具模块初始化文件
|
||||||
|
"""
|
||||||
|
|
||||||
|
from utils.code_executor import CodeExecutor
|
||||||
|
from utils.llm_helper import LLMHelper
|
||||||
|
from utils.fallback_openai_client import AsyncFallbackOpenAIClient
|
||||||
|
|
||||||
|
__all__ = ["CodeExecutor", "LLMHelper", "AsyncFallbackOpenAIClient"]
|
||||||
453
utils/code_executor.py
Normal file
453
utils/code_executor.py
Normal file
@@ -0,0 +1,453 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
安全的代码执行器,基于 IPython 提供 notebook 环境下的代码执行功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import ast
|
||||||
|
import traceback
|
||||||
|
import io
|
||||||
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
|
from contextlib import redirect_stdout, redirect_stderr
|
||||||
|
from IPython.core.interactiveshell import InteractiveShell
|
||||||
|
from IPython.utils.capture import capture_output
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.font_manager as fm
|
||||||
|
|
||||||
|
|
||||||
|
class CodeExecutor:
    """
    Sandboxed code executor: restricts importable libraries, captures output,
    and supports saving figures and reporting their paths.
    """

    # Whitelist of module names that executed code is allowed to import.
    # Contains both full module paths (e.g. "matplotlib.pyplot") and the
    # conventional aliases ("pd", "np", "plt", "sns").
    ALLOWED_IMPORTS = {
        "pandas",
        "pd",
        "numpy",
        "np",
        "matplotlib",
        "matplotlib.pyplot",
        "plt",
        "seaborn",
        "sns",
        "duckdb",
        "scipy",
        "sklearn",
        "plotly",
        "dash",
        "requests",
        "urllib",
        "os",
        "sys",
        "json",
        "csv",
        "datetime",
        "time",
        "math",
        "statistics",
        "re",
        "pathlib",
        "io",
        "collections",
        "itertools",
        "functools",
        "operator",
        "warnings",
        "logging",
        "copy",
        "pickle",
        "gzip",
        "zipfile",
        "yaml",
        "typing",
        "dataclasses",
        "enum",
        "sqlite3",
    }
|
||||||
|
|
||||||
|
    def __init__(self, output_dir: str = "outputs"):
        """
        Initialize the code executor.

        Args:
            output_dir: directory where generated images and files are saved.
        """
        self.output_dir = os.path.abspath(output_dir)
        os.makedirs(self.output_dir, exist_ok=True)

        # Embedded IPython shell that actually runs user code.
        self.shell = InteractiveShell.instance()

        # Configure matplotlib (Agg backend + CJK fonts) before any plotting.
        self._setup_chinese_font()

        # Pre-import commonly used libraries into the shell namespace.
        self._setup_common_imports()

        # Counter used when naming saved figures.
        self.image_counter = 0
|
||||||
|
|
||||||
|
    def _setup_chinese_font(self):
        """Configure matplotlib so Chinese labels render correctly (Agg backend)."""
        try:
            # Use the Agg backend to avoid GUI/display requirements.
            matplotlib.use("Agg")

            # Names of all fonts matplotlib knows about on this machine.
            available_fonts = [f.name for f in fm.fontManager.ttflist]

            # Candidate CJK fonts, ordered by preference
            # (macOS first, then Windows / Linux fallbacks).
            chinese_fonts = [
                "Hiragino Sans GB",  # macOS Simplified Chinese
                "Songti SC",  # macOS Songti (Simplified)
                "PingFang SC",  # macOS PingFang (Simplified)
                "Heiti SC",  # macOS Heiti (Simplified)
                "Heiti TC",  # macOS Heiti (Traditional)
                "PingFang HK",  # macOS PingFang (Hong Kong)
                "SimHei",  # Windows SimHei
                "STHeiti",  # STHeiti
                "WenQuanYi Micro Hei",  # Linux WenQuanYi Micro Hei
                "DejaVu Sans",  # default sans-serif fallback
                "Arial Unicode MS",  # Arial Unicode
            ]

            # Keep only the candidates that are actually installed.
            system_chinese_fonts = [
                font for font in chinese_fonts if font in available_fonts
            ]

            # No exact match: fall back to a fuzzy substring search.
            if not system_chinese_fonts:
                print("警告:未找到精确匹配的中文字体,尝试更宽松的搜索...")
                # Accept any installed font whose name contains a typical
                # CJK-font keyword.
                fallback_fonts = []
                for available_font in available_fonts:
                    if any(
                        keyword in available_font
                        for keyword in [
                            "Hei",
                            "Song",
                            "Fang",
                            "Kai",
                            "Hiragino",
                            "PingFang",
                            "ST",
                        ]
                    ):
                        fallback_fonts.append(available_font)

                if fallback_fonts:
                    system_chinese_fonts = fallback_fonts[:3]  # keep the first 3 matches
                    print(f"找到备选中文字体: {system_chinese_fonts}")
                else:
                    print("警告:系统中未找到合适的中文字体,使用系统默认字体")
                    system_chinese_fonts = ["DejaVu Sans", "Arial Unicode MS"]

            # Apply the font configuration in this process.
            plt.rcParams["font.sans-serif"] = system_chinese_fonts + [
                "DejaVu Sans",
                "Arial Unicode MS",
            ]

            plt.rcParams["axes.unicode_minus"] = False
            plt.rcParams["font.family"] = "sans-serif"

            # Mirror the same font configuration inside the IPython shell
            # where the generated analysis code will actually run.
            font_list_str = str(
                system_chinese_fonts + ["DejaVu Sans", "Arial Unicode MS"]
            )
            self.shell.run_cell(
                f"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

# 设置中文字体
plt.rcParams['font.sans-serif'] = {font_list_str}
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.family'] = 'sans-serif'

# 确保matplotlib缓存目录可写
import os
cache_dir = os.path.expanduser('~/.matplotlib')
if not os.path.exists(cache_dir):
    os.makedirs(cache_dir, exist_ok=True)
os.environ['MPLCONFIGDIR'] = cache_dir
"""
            )
        except Exception as e:
            print(f"设置中文字体失败: {e}")
            # Even on failure, apply the minimal safe matplotlib settings.
            try:
                matplotlib.use("Agg")
                plt.rcParams["axes.unicode_minus"] = False
            except:
                pass
|
||||||
|
|
||||||
|
    def _setup_common_imports(self):
        """Pre-import frequently used libraries into the shell's user namespace."""
        common_imports = """
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import duckdb
import os
import json
from IPython.display import display
"""
        try:
            self.shell.run_cell(common_imports)
            # Inject display() directly as well, so it is available even if
            # the run_cell import did not land in the user namespace.
            from IPython.display import display

            self.shell.user_ns["display"] = display
        except Exception as e:
            print(f"预导入库失败: {e}")
|
||||||
|
|
||||||
|
def _check_code_safety(self, code: str) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
检查代码安全性,限制导入的库
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(is_safe, error_message)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tree = ast.parse(code)
|
||||||
|
except SyntaxError as e:
|
||||||
|
return False, f"语法错误: {e}"
|
||||||
|
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.Import):
|
||||||
|
for alias in node.names:
|
||||||
|
if alias.name not in self.ALLOWED_IMPORTS:
|
||||||
|
return False, f"不允许的导入: {alias.name}"
|
||||||
|
|
||||||
|
elif isinstance(node, ast.ImportFrom):
|
||||||
|
if node.module not in self.ALLOWED_IMPORTS:
|
||||||
|
return False, f"不允许的导入: {node.module}"
|
||||||
|
|
||||||
|
# 检查属性访问(防止通过os.system等方式绕过)
|
||||||
|
elif isinstance(node, ast.Attribute):
|
||||||
|
# 检查是否访问了os模块的属性
|
||||||
|
if isinstance(node.value, ast.Name) and node.value.id == "os":
|
||||||
|
# 允许的os子模块和函数白名单
|
||||||
|
allowed_os_attributes = {
|
||||||
|
"path", "environ", "getcwd", "listdir", "makedirs", "mkdir", "remove", "rmdir",
|
||||||
|
"path.join", "path.exists", "path.abspath", "path.dirname",
|
||||||
|
"path.basename", "path.splitext", "path.isdir", "path.isfile",
|
||||||
|
"sep", "name", "linesep", "stat", "getpid"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 检查直接属性访问 (如 os.getcwd)
|
||||||
|
if node.attr not in allowed_os_attributes:
|
||||||
|
# 进一步检查如果是 os.path.xxx 这种形式
|
||||||
|
# Note: ast.Attribute 嵌套结构比较复杂,简单处理只允许 os.path 和上述白名单
|
||||||
|
if node.attr == "path":
|
||||||
|
pass # 允许访问 os.path
|
||||||
|
else:
|
||||||
|
return False, f"不允许的os属性访问: os.{node.attr}"
|
||||||
|
|
||||||
|
# 检查危险函数调用
|
||||||
|
elif isinstance(node, ast.Call):
|
||||||
|
if isinstance(node.func, ast.Name):
|
||||||
|
if node.func.id in ["exec", "eval", "open", "__import__"]:
|
||||||
|
return False, f"不允许的函数调用: {node.func.id}"
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
def get_current_figures_info(self) -> List[Dict[str, Any]]:
|
||||||
|
"""获取当前matplotlib图形信息,但不自动保存"""
|
||||||
|
figures_info = []
|
||||||
|
|
||||||
|
# 获取当前所有图形
|
||||||
|
fig_nums = plt.get_fignums()
|
||||||
|
|
||||||
|
for fig_num in fig_nums:
|
||||||
|
fig = plt.figure(fig_num)
|
||||||
|
if fig.get_axes(): # 只处理有内容的图形
|
||||||
|
figures_info.append(
|
||||||
|
{
|
||||||
|
"figure_number": fig_num,
|
||||||
|
"axes_count": len(fig.get_axes()),
|
||||||
|
"figure_size": fig.get_size_inches().tolist(),
|
||||||
|
"has_content": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return figures_info
|
||||||
|
|
||||||
|
def _format_table_output(self, obj: Any) -> str:
|
||||||
|
"""格式化表格输出,限制行数"""
|
||||||
|
if hasattr(obj, "shape") and hasattr(obj, "head"): # pandas DataFrame
|
||||||
|
rows, cols = obj.shape
|
||||||
|
print(f"\n数据表形状: {rows}行 x {cols}列")
|
||||||
|
print(f"列名: {list(obj.columns)}")
|
||||||
|
|
||||||
|
if rows <= 15:
|
||||||
|
return str(obj)
|
||||||
|
else:
|
||||||
|
head_part = obj.head(5)
|
||||||
|
tail_part = obj.tail(5)
|
||||||
|
return f"{head_part}\n...\n(省略 {rows-10} 行)\n...\n{tail_part}"
|
||||||
|
|
||||||
|
return str(obj)
|
||||||
|
|
||||||
|
    def execute_code(self, code: str) -> Dict[str, Any]:
        """
        Execute *code* in the embedded IPython shell and return the outcome.

        Args:
            code: Python source to execute.

        Returns:
            {
                'success': bool,        # True when execution completed cleanly
                'output': str,          # captured stdout (+ repr of the result)
                'error': str,           # error description when success is False
                'variables': Dict[str, Any]  # notable variables created by this run
            }
        """
        # Static safety screen (import whitelist, blocked os attrs, etc.).
        is_safe, safety_error = self._check_code_safety(code)
        if not is_safe:
            return {
                "success": False,
                "output": "",
                "error": f"代码安全检查失败: {safety_error}",
                "variables": {},
            }

        # Snapshot variable names so newly created ones can be detected after the run.
        vars_before = set(self.shell.user_ns.keys())

        try:
            # capture_output grabs stdout/stderr/rich display output of the cell.
            with capture_output() as captured:
                result = self.shell.run_cell(code)

            # Error raised before the cell body ran (e.g. input transformation).
            if result.error_before_exec:
                error_msg = str(result.error_before_exec)
                return {
                    "success": False,
                    "output": captured.stdout,
                    "error": f"执行前错误: {error_msg}",
                    "variables": {},
                }

            # Error raised while the cell body was executing.
            if result.error_in_exec:
                error_msg = str(result.error_in_exec)
                return {
                    "success": False,
                    "output": captured.stdout,
                    "error": f"执行错误: {error_msg}",
                    "variables": {},
                }

            # Collect captured stdout.
            output = captured.stdout

            # Append the cell's result value, abbreviated for large tables.
            if result.result is not None:
                formatted_result = self._format_table_output(result.result)
                output += f"\n{formatted_result}"
            # Identify variables created by this cell (simplified diff of names).
            vars_after = set(self.shell.user_ns.keys())
            new_vars = vars_after - vars_before

            # Only record significant data structures (DataFrames, arrays)
            # plus a few known configuration variables.
            important_new_vars = {}
            for var_name in new_vars:
                if not var_name.startswith("_"):
                    try:
                        var_value = self.shell.user_ns[var_name]
                        if hasattr(var_value, "shape"):  # pandas DataFrame, numpy array
                            important_new_vars[var_name] = (
                                f"{type(var_value).__name__} with shape {var_value.shape}"
                            )
                        elif var_name in ["session_output_dir"]:  # key config variable
                            important_new_vars[var_name] = str(var_value)
                    except:
                        pass

            return {
                "success": True,
                "output": output,
                "error": "",
                "variables": important_new_vars,
            }
        except Exception as e:
            return {
                "success": False,
                # `captured` does not exist if capture_output() itself failed.
                "output": captured.stdout if "captured" in locals() else "",
                "error": f"执行异常: {str(e)}\n{traceback.format_exc()}",
                "variables": {},
            }
|
||||||
|
|
||||||
|
    def reset_environment(self):
        """Reset the shell to a clean state and re-apply imports, fonts and counters."""
        self.shell.reset()
        self._setup_common_imports()
        self._setup_chinese_font()
        # Drop any figures left over from the previous session.
        plt.close("all")
        self.image_counter = 0
|
||||||
|
|
||||||
|
def set_variable(self, name: str, value: Any):
|
||||||
|
"""设置执行环境中的变量"""
|
||||||
|
self.shell.user_ns[name] = value
|
||||||
|
|
||||||
|
    def get_environment_info(self) -> str:
        """Summarize notable shell variables, formatted for the system prompt.

        Surfaces DataFrames/arrays (with shape), short scalar values, imported
        data-science modules, and the session output directory, while skipping
        IPython internals.
        """
        info_parts = []

        # Classify user-namespace variables worth showing to the LLM.
        important_vars = {}
        for var_name, var_value in self.shell.user_ns.items():
            if not var_name.startswith("_") and var_name not in [
                "In",
                "Out",
                "get_ipython",
                "exit",
                "quit",
            ]:
                try:
                    if hasattr(var_value, "shape"):  # pandas DataFrame, numpy array
                        important_vars[var_name] = (
                            f"{type(var_value).__name__} with shape {var_value.shape}"
                        )
                    elif var_name in ["session_output_dir"]:  # key path variable
                        important_vars[var_name] = str(var_value)
                    elif (
                        isinstance(var_value, (int, float, str, bool))
                        and len(str(var_value)) < 100
                    ):
                        # Short scalars only; long strings would bloat the prompt.
                        important_vars[var_name] = (
                            f"{type(var_value).__name__}: {var_value}"
                        )
                    elif hasattr(var_value, "__module__") and var_value.__module__ in [
                        "pandas",
                        "numpy",
                        "matplotlib.pyplot",
                    ]:
                        important_vars[var_name] = f"导入的模块: {var_value.__module__}"
                except:
                    continue

        if important_vars:
            info_parts.append("当前环境变量:")
            for var_name, var_info in important_vars.items():
                info_parts.append(f"- {var_name}: {var_info}")
        else:
            info_parts.append("当前环境已预装pandas, numpy, matplotlib等库")

        # Tell the model where generated figures must be saved.
        if "session_output_dir" in self.shell.user_ns:
            info_parts.append(
                f"图片保存目录: session_output_dir = '{self.shell.user_ns['session_output_dir']}'"
            )

        return "\n".join(info_parts)
|
||||||
15
utils/create_session_dir.py
Normal file
15
utils/create_session_dir.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
def create_session_output_dir(base_output_dir, user_input: str) -> str:
    """Create and return a unique, timestamped output directory for one session.

    The directory is named ``session_YYYYMMDD_HHMMSS`` under *base_output_dir*.

    Args:
        base_output_dir: parent directory for all session directories.
        user_input: the user's query.
            NOTE(review): currently unused; kept for interface stability.
    """
    session_name = "session_" + datetime.now().strftime("%Y%m%d_%H%M%S")
    session_dir = os.path.join(base_output_dir, session_name)
    os.makedirs(session_dir, exist_ok=True)
    return session_dir
|
||||||
90
utils/data_loader.py
Normal file
90
utils/data_loader.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
import io
|
||||||
|
|
||||||
|
def load_and_profile_data(file_paths: list) -> str:
    """
    Load each data file and build a Markdown data-profile report.

    Args:
        file_paths: list of file paths (.csv / .xlsx / .xls supported)

    Returns:
        Markdown string describing each file's shape, columns and
        per-column statistics (numeric stats, top categorical values,
        datetime ranges).
    """
    profile_summary = "# 数据画像报告 (Data Profile)\n\n"

    if not file_paths:
        return profile_summary + "未提供数据文件。"

    for file_path in file_paths:
        file_name = os.path.basename(file_path)
        profile_summary += f"## 文件: {file_name}\n\n"

        if not os.path.exists(file_path):
            profile_summary += f"⚠️ 文件不存在: {file_path}\n\n"
            continue

        try:
            # Pick a loader by file extension.
            ext = os.path.splitext(file_path)[1].lower()
            if ext == '.csv':
                # Try encodings in order; latin1 last (it accepts any bytes).
                try:
                    df = pd.read_csv(file_path, encoding='utf-8')
                except UnicodeDecodeError:
                    try:
                        df = pd.read_csv(file_path, encoding='gbk')
                    except Exception:
                        df = pd.read_csv(file_path, encoding='latin1')
            elif ext in ['.xlsx', '.xls']:
                df = pd.read_excel(file_path)
            else:
                profile_summary += f"⚠️ 不支持的文件格式: {ext}\n\n"
                continue

            # Basic shape / columns.
            rows, cols = df.shape
            profile_summary += f"- **维度**: {rows} 行 x {cols} 列\n"
            profile_summary += f"- **列名**: `{', '.join(df.columns)}`\n\n"

            profile_summary += "### 列详细分布:\n"

            # Per-column profile.
            for col in df.columns:
                dtype = df[col].dtype
                null_count = int(df[col].isnull().sum())
                # Guard against ZeroDivisionError for header-only files.
                null_ratio = (null_count / rows) * 100 if rows else 0.0

                profile_summary += f"#### {col} ({dtype})\n"
                if null_count > 0:
                    profile_summary += f"- ⚠️ 空值: {null_count} ({null_ratio:.1f}%)\n"

                # Numeric columns: basic descriptive statistics.
                if pd.api.types.is_numeric_dtype(dtype):
                    if rows:
                        desc = df[col].describe()
                        profile_summary += f"- 统计: Min={desc['min']:.2f}, Max={desc['max']:.2f}, Mean={desc['mean']:.2f}\n"

                # Text / categorical columns.
                # isinstance check replaces pd.api.types.is_categorical_dtype,
                # which is deprecated since pandas 2.x.
                elif pd.api.types.is_object_dtype(dtype) or isinstance(dtype, pd.CategoricalDtype):
                    unique_count = df[col].nunique()
                    profile_summary += f"- 唯一值数量: {unique_count}\n"

                    # Show the top values; crucial for spotting high-frequency issues.
                    if unique_count > 0:
                        top_n = df[col].value_counts().head(5)
                        top_items_str = ", ".join([f"{k}({v})" for k, v in top_n.items()])
                        profile_summary += f"- **TOP 5 高频值**: {top_items_str}\n"

                # Datetime columns: report the observed range.
                elif pd.api.types.is_datetime64_any_dtype(dtype):
                    profile_summary += f"- 范围: {df[col].min()} 至 {df[col].max()}\n"

                profile_summary += "\n"

        except Exception as e:
            profile_summary += f"❌ 读取或分析文件失败: {str(e)}\n\n"

    return profile_summary
|
||||||
38
utils/extract_code.py
Normal file
38
utils/extract_code.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def extract_code_from_response(response: str) -> Optional[str]:
    """Extract executable Python code from an LLM response.

    Strategy, in order:
      1. Parse a fenced ```yaml block (or the first bare ``` block, or the
         whole response) as YAML and return its ``code`` field.
      2. Return the content of a fenced ```python block.
      3. Return the content of any remaining fenced block.

    Returns:
        The extracted code string, or None when nothing code-like is found.
    """

    def _fenced(text: str, marker: str) -> Optional[str]:
        """Return the content of the first block opened by *marker*, or None.

        An unterminated fence yields the rest of the text (the original
        implementation sliced to index -1, silently dropping the last
        character).
        """
        start = text.find(marker)
        if start == -1:
            return None
        start += len(marker)
        end = text.find('```', start)
        if end == -1:
            return text[start:]
        return text[start:end]

    # 1) YAML payload with a `code` field.
    yaml_content = _fenced(response, '```yaml') or _fenced(response, '```') or response
    try:
        yaml_data = yaml.safe_load(yaml_content.strip())
        # safe_load may return a str/list; `'code' in <str>` would be a
        # substring test, so require an actual mapping before indexing.
        if isinstance(yaml_data, dict) and 'code' in yaml_data:
            return yaml_data['code']
    except Exception:
        # Not valid YAML (or not the expected shape); fall through to
        # plain fenced-block extraction.
        pass

    # 2) Explicit ```python fence, then 3) any fence.
    for marker in ('```python', '```'):
        block = _fenced(response, marker)
        if block is not None:
            return block.strip()

    return None
|
||||||
230
utils/fallback_openai_client.py
Normal file
230
utils/fallback_openai_client.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional, Any, Mapping, Dict
|
||||||
|
from openai import AsyncOpenAI, APIStatusError, APIConnectionError, APITimeoutError, APIError
|
||||||
|
from openai.types.chat import ChatCompletion
|
||||||
|
|
||||||
|
class AsyncFallbackOpenAIClient:
|
||||||
|
"""
|
||||||
|
一个支持备用 API 自动切换的异步 OpenAI 客户端。
|
||||||
|
当主 API 调用因特定错误(如内容过滤)失败时,会自动尝试使用备用 API。
|
||||||
|
"""
|
||||||
|
    def __init__(
        self,
        primary_api_key: str,
        primary_base_url: str,
        primary_model_name: str,
        fallback_api_key: Optional[str] = None,
        fallback_base_url: Optional[str] = None,
        fallback_model_name: Optional[str] = None,
        primary_client_args: Optional[Dict[str, Any]] = None,
        fallback_client_args: Optional[Dict[str, Any]] = None,
        content_filter_error_code: str = "1301",  # Zhipu-specific content-filter error code
        content_filter_error_field: str = "contentFilter",  # Zhipu-specific content-filter error field
        max_retries_primary: int = 1,  # retry count for the primary API
        max_retries_fallback: int = 1,  # retry count for the fallback API
        retry_delay_seconds: float = 1.0  # delay between retries (seconds)
    ):
        """
        Initialize the AsyncFallbackOpenAIClient.

        Args:
            primary_api_key: API key for the primary endpoint.
            primary_base_url: base URL of the primary endpoint.
            primary_model_name: model name used on the primary endpoint.
            fallback_api_key: API key for the fallback endpoint (optional).
            fallback_base_url: base URL of the fallback endpoint (optional).
            fallback_model_name: model name used on the fallback endpoint (optional).
            primary_client_args: extra kwargs for the primary AsyncOpenAI client.
            fallback_client_args: extra kwargs for the fallback AsyncOpenAI client.
            content_filter_error_code: error code that marks a content-filter
                rejection and triggers the fallback.
            content_filter_error_field: field whose presence marks a
                content-filter rejection.
            max_retries_primary: maximum retries against the primary API.
            max_retries_fallback: maximum retries against the fallback API.
            retry_delay_seconds: delay before a retry, in seconds.
        """
        if not primary_api_key or not primary_base_url:
            raise ValueError("主 API 密钥和基础 URL 不能为空。")

        _primary_args = primary_client_args or {}
        self.primary_client = AsyncOpenAI(api_key=primary_api_key, base_url=primary_base_url, **_primary_args)
        self.primary_model_name = primary_model_name

        # The fallback client is optional; without a complete configuration
        # no fallback will occur when the primary API fails.
        self.fallback_client: Optional[AsyncOpenAI] = None
        self.fallback_model_name: Optional[str] = None
        if fallback_api_key and fallback_base_url and fallback_model_name:
            _fallback_args = fallback_client_args or {}
            self.fallback_client = AsyncOpenAI(api_key=fallback_api_key, base_url=fallback_base_url, **_fallback_args)
            self.fallback_model_name = fallback_model_name
        else:
            print("⚠️ 警告: 未完全配置备用 API 客户端。如果主 API 失败,将无法进行回退。")

        self.content_filter_error_code = content_filter_error_code
        self.content_filter_error_field = content_filter_error_field
        self.max_retries_primary = max_retries_primary
        self.max_retries_fallback = max_retries_fallback
        self.retry_delay_seconds = retry_delay_seconds
        self._closed = False
|
||||||
|
|
||||||
|
    async def _attempt_api_call(
        self,
        client: AsyncOpenAI,
        model_name: str,
        messages: list[Mapping[str, Any]],
        max_retries: int,
        api_name: str,
        **kwargs: Any
    ) -> ChatCompletion:
        """
        Call one OpenAI-compatible client with retry handling.

        Connection/timeout errors and status errors are retried up to
        *max_retries* times with a linearly growing delay.  A content-filter
        error from the primary API is re-raised immediately so the caller can
        switch to the fallback.  Other APIErrors abort without retry.

        Raises:
            The last exception encountered when all attempts fail.
        """
        last_exception = None
        for attempt in range(max_retries + 1):
            try:
                completion = await client.chat.completions.create(
                    # An explicit `model` kwarg overrides the configured name.
                    model=kwargs.pop('model', model_name),
                    messages=messages,
                    **kwargs
                )
                return completion
            except (APIConnectionError, APITimeoutError) as e:  # transient network errors: retryable
                last_exception = e
                print(f"⚠️ {api_name} API 调用时发生可重试错误 ({type(e).__name__}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
                if attempt < max_retries:
                    await asyncio.sleep(self.retry_delay_seconds * (attempt + 1))  # linearly increasing back-off
                else:
                    print(f"❌ {api_name} API 在达到最大重试次数后仍然失败。")
            except APIStatusError as e:  # HTTP status errors returned by the API
                # Detect the provider-specific content-filter rejection shape.
                is_content_filter_error = False
                if e.status_code == 400:
                    try:
                        error_json = e.response.json()
                        error_details = error_json.get("error", {})
                        if (error_details.get("code") == self.content_filter_error_code and
                            self.content_filter_error_field in error_json):
                            is_content_filter_error = True
                    except Exception:
                        pass  # could not parse the error body; treat as non-filter

                # Re-raise primary-API filter errors immediately so the caller falls back.
                if is_content_filter_error and api_name == "主":
                    raise e

                last_exception = e
                print(f"⚠️ {api_name} API 调用时发生 APIStatusError ({e.status_code}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
                if attempt < max_retries:
                    await asyncio.sleep(self.retry_delay_seconds * (attempt + 1))
                else:
                    print(f"❌ {api_name} API 在达到最大重试次数后仍然失败 (APIStatusError)。")
            except APIError as e:  # other OpenAI errors are not retried
                last_exception = e
                print(f"❌ {api_name} API 调用时发生不可重试错误 ({type(e).__name__}): {e}")
                break

        if last_exception:
            raise last_exception
        raise RuntimeError(f"{api_name} API 调用意外失败。")  # defensive: should be unreachable
|
||||||
|
|
||||||
|
async def chat_completions_create(
    self,
    messages: list[Mapping[str, Any]],
    **kwargs: Any  # Extra OpenAI parameters, e.g. max_tokens, temperature.
) -> ChatCompletion:
    """Create a chat completion via the primary API, falling back to the backup API.

    The fallback API is used when the primary API reports the configured
    content-filter error, or when the primary call ultimately fails with an
    APIError. Per-call retries for transient errors are handled inside
    _attempt_api_call for both the primary and the fallback API.

    Args:
        messages: Message list for the OpenAI API.
        **kwargs: Additional parameters passed through to the OpenAI API call.

    Returns:
        The ChatCompletion object.

    Raises:
        APIError: If the primary API and (when attempted) the fallback API both fail.
        RuntimeError: If the client has already been closed.
    """
    if self._closed:
        raise RuntimeError("客户端已关闭。")

    try:
        # Pass a copy of kwargs: _attempt_api_call pops 'model' from it, and the
        # untouched kwargs may be needed again for a fallback attempt below.
        completion = await self._attempt_api_call(
            client=self.primary_client,
            model_name=self.primary_model_name,
            messages=messages,
            max_retries=self.max_retries_primary,
            api_name="主",
            **kwargs.copy()
        )
        return completion
    except APIStatusError as e_primary:
        # Inspect a 400 response body to decide whether this is the provider's
        # content-filter error (matched by error code plus a marker field).
        is_content_filter_error = False
        if e_primary.status_code == 400:
            try:
                error_json = e_primary.response.json()
                error_details = error_json.get("error", {})
                if (error_details.get("code") == self.content_filter_error_code and
                        self.content_filter_error_field in error_json):
                    is_content_filter_error = True
            except Exception:
                pass  # Failed to parse the error body; treat as not content-filter.

        if is_content_filter_error and self.fallback_client and self.fallback_model_name:
            print(f"ℹ️ 主 API 内容过滤错误 ({e_primary.status_code})。尝试切换到备用 API ({self.fallback_client.base_url})...")
            try:
                fallback_completion = await self._attempt_api_call(
                    client=self.fallback_client,
                    model_name=self.fallback_model_name,
                    messages=messages,
                    max_retries=self.max_retries_fallback,
                    api_name="备用",
                    **kwargs.copy()
                )
                print(f"✅ 备用 API 调用成功。")
                return fallback_completion
            except APIError as e_fallback:
                print(f"❌ 备用 API 调用最终失败: {type(e_fallback).__name__} - {e_fallback}")
                raise e_fallback
        else:
            if not (self.fallback_client and self.fallback_model_name and is_content_filter_error):
                # Not a content-filter error, or no usable fallback API configured:
                # log the primary API's original error before re-raising it.
                print(f"ℹ️ 主 API 错误 ({type(e_primary).__name__}: {e_primary}), 且不满足备用条件或备用API未配置。")
            raise e_primary
    except APIError as e_primary_other:
        # Any other APIError from the primary API (retries already exhausted in
        # _attempt_api_call): try the fallback API if one is configured.
        print(f"❌ 主 API 调用最终失败 (非内容过滤,错误类型: {type(e_primary_other).__name__}): {e_primary_other}")
        if self.fallback_client and self.fallback_model_name:
            print(f"ℹ️ 主 API 失败,尝试切换到备用 API ({self.fallback_client.base_url})...")
            try:
                fallback_completion = await self._attempt_api_call(
                    client=self.fallback_client,
                    model_name=self.fallback_model_name,
                    messages=messages,
                    max_retries=self.max_retries_fallback,
                    api_name="备用",
                    **kwargs.copy()
                )
                print(f"✅ 备用 API 调用成功。")
                return fallback_completion
            except APIError as e_fallback_after_primary_fail:
                print(f"❌ 备用 API 在主 API 失败后也调用失败: {type(e_fallback_after_primary_fail).__name__} - {e_fallback_after_primary_fail}")
                raise e_fallback_after_primary_fail
        else:
            raise e_primary_other
|
||||||
|
|
||||||
|
async def close(self):
    """Asynchronously close the primary client and, if present, the fallback client.

    Safe to call more than once: after the first successful close, later
    calls are no-ops.
    """
    if self._closed:
        return
    await self.primary_client.close()
    if self.fallback_client:
        await self.fallback_client.close()
    self._closed = True
|
||||||
|
|
||||||
|
async def __aenter__(self):
    """Enter the async context, rejecting reuse of a closed client."""
    if not self._closed:
        return self
    raise RuntimeError("AsyncFallbackOpenAIClient 不能在关闭后重新进入。请创建一个新实例。")
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Exit the async context by closing both underlying clients.

    Returns None implicitly, so exceptions raised inside the `async with`
    body are never suppressed.
    """
    await self.close()
|
||||||
25
utils/format_execution_result.py
Normal file
25
utils/format_execution_result.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
def format_execution_result(result: Dict[str, Any]) -> str:
    """Format a code-execution result dict into user-readable feedback.

    Args:
        result: Execution result mapping. Recognized keys: ``success``
            (bool), ``output`` (str), ``error`` (str, failure case) and
            ``variables`` (mapping of variable name -> description).
            All keys are optional; missing keys no longer raise KeyError
            (the original mixed ``result[...]`` and ``result.get(...)``
            access, crashing on partial results).

    Returns:
        A newline-joined feedback string.
    """
    feedback = []

    if result.get('success'):
        feedback.append("✅ 代码执行成功")

        output = result.get('output')
        if output:
            feedback.append(f"📊 输出结果:\n{output}")

        variables = result.get('variables')
        if variables:
            feedback.append("📋 新生成的变量:")
            for var_name, var_info in variables.items():
                feedback.append(f"  - {var_name}: {var_info}")
    else:
        feedback.append("❌ 代码执行失败")
        feedback.append(f"错误信息: {result.get('error', '')}")
        output = result.get('output')
        if output:
            feedback.append(f"部分输出: {output}")

    return "\n".join(feedback)
|
||||||
86
utils/llm_helper.py
Normal file
86
utils/llm_helper.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
LLM调用辅助模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import yaml
|
||||||
|
from config.llm_config import LLMConfig
|
||||||
|
from utils.fallback_openai_client import AsyncFallbackOpenAIClient
|
||||||
|
|
||||||
|
class LLMHelper:
    """Helper for LLM calls, supporting both synchronous and asynchronous use.

    Wraps AsyncFallbackOpenAIClient with the connection settings taken from
    an LLMConfig, and offers a YAML-extraction utility for model replies.
    """

    def __init__(self, config: LLMConfig = None):
        """Initialize the helper from an LLMConfig.

        Args:
            config: Connection and model settings (api_key, base_url, model,
                max_tokens, temperature). Required; a None value is rejected
                up front instead of failing later with an opaque
                AttributeError on ``config.api_key``.

        Raises:
            ValueError: If ``config`` is None.
        """
        if config is None:
            raise ValueError("LLMHelper requires a non-None LLMConfig")
        self.config = config
        self.client = AsyncFallbackOpenAIClient(
            primary_api_key=config.api_key,
            primary_base_url=config.base_url,
            primary_model_name=config.model
        )

    async def async_call(self, prompt: str, system_prompt: str = None, max_tokens: int = None, temperature: float = None) -> str:
        """Call the LLM asynchronously and return the reply text.

        Args:
            prompt: User message content.
            system_prompt: Optional system message prepended to the chat.
            max_tokens: Per-call override; defaults to config.max_tokens.
            temperature: Per-call override; defaults to config.temperature.

        Returns:
            The model's reply text, or "" if the call failed (deliberate
            best-effort behavior: callers treat "" as "no answer").
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Per-call overrides fall back to the configured defaults.
        kwargs = {
            'max_tokens': max_tokens if max_tokens is not None else self.config.max_tokens,
            'temperature': temperature if temperature is not None else self.config.temperature,
        }

        try:
            response = await self.client.chat_completions_create(
                messages=messages,
                **kwargs
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"LLM调用失败: {e}")
            return ""

    def call(self, prompt: str, system_prompt: str = None, max_tokens: int = None, temperature: float = None) -> str:
        """Synchronous wrapper around async_call() (same arguments and return)."""
        try:
            # NOTE(review): get_event_loop() is deprecated outside a running
            # loop on modern Python; kept here to reuse an existing loop when
            # one is set — confirm against supported Python versions.
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # nest_asyncio (third-party) allows run_until_complete inside an
        # already-running loop (e.g. Jupyter); imported lazily so pure-async
        # users never need it installed.
        import nest_asyncio
        nest_asyncio.apply()

        return loop.run_until_complete(self.async_call(prompt, system_prompt, max_tokens, temperature))

    def parse_yaml_response(self, response: str) -> dict:
        """Parse a YAML payload out of an LLM response.

        Handles ```yaml fenced blocks, generic ``` fences, or bare YAML.
        Bug fix: when a fence had no closing ``` the old code used the -1
        returned by str.find as a slice end, silently dropping the final
        character of the payload; an unterminated fence now extends to the
        end of the response.

        Returns:
            The parsed dict, or {} if parsing failed.
        """
        try:
            if '```yaml' in response:
                start = response.find('```yaml') + 7
                end = response.find('```', start)
                yaml_content = response[start:end if end != -1 else None].strip()
            elif '```' in response:
                start = response.find('```') + 3
                end = response.find('```', start)
                yaml_content = response[start:end if end != -1 else None].strip()
            else:
                yaml_content = response.strip()

            return yaml.safe_load(yaml_content)
        except Exception as e:
            print(f"YAML解析失败: {e}")
            print(f"原始响应: {response}")
            return {}

    async def close(self):
        """Close the underlying fallback-capable OpenAI client."""
        await self.client.close()
|
||||||
Reference in New Issue
Block a user