From 7f46f25a4bc35b83049e23a92dc0b3be99207cc2 Mon Sep 17 00:00:00 2001 From: Zhaojie Date: Tue, 6 Jan 2026 14:09:12 +0800 Subject: [PATCH] feat: implement initial structure and core components for data analysis agent --- .env.example | 8 + .gitignore | 173 +++++++++++ LICENSE | 21 ++ __init__.py | 54 ++++ config/__init__.py | 8 + config/llm_config.py | 44 +++ data_analysis_agent.py | 483 +++++++++++++++++++++++++++++++ main.py | 18 ++ prompts.py | 286 ++++++++++++++++++ requirements.txt | 52 ++++ utils/__init__.py | 10 + utils/code_executor.py | 453 +++++++++++++++++++++++++++++ utils/create_session_dir.py | 15 + utils/data_loader.py | 90 ++++++ utils/extract_code.py | 38 +++ utils/fallback_openai_client.py | 230 +++++++++++++++ utils/format_execution_result.py | 25 ++ utils/llm_helper.py | 86 ++++++ 18 files changed, 2094 insertions(+) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 __init__.py create mode 100644 config/__init__.py create mode 100644 config/llm_config.py create mode 100644 data_analysis_agent.py create mode 100644 main.py create mode 100644 prompts.py create mode 100644 requirements.txt create mode 100644 utils/__init__.py create mode 100644 utils/code_executor.py create mode 100644 utils/create_session_dir.py create mode 100644 utils/data_loader.py create mode 100644 utils/extract_code.py create mode 100644 utils/fallback_openai_client.py create mode 100644 utils/format_execution_result.py create mode 100644 utils/llm_helper.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..25df76c --- /dev/null +++ b/.env.example @@ -0,0 +1,8 @@ + +# 火山引擎配置 +OPENAI_API_KEY=sk-c44i1hy64xgzwox6x08o4zug93frq6rgn84oqugf2pje1tg4 +OPENAI_BASE_URL=https://api.xiaomimimo.com/v1/chat/completions +# 文本模型 +OPENAI_MODEL=mimo-v2-flash +# OPENAI_MODEL=deepseek-r1-250528 + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..94f8d30 --- /dev/null +++ b/.gitignore @@ -0,0 +1,173 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Project specific +# Output files and generated reports +outputs/ +*.png +*.jpg +*.jpeg +*.pdf +*.docx +*.xlsx +*.csv +!贵州茅台利润表.csv + +# 允许assets目录下的图片文件(项目资源) +!assets/**/*.png +!assets/**/*.jpg +!assets/**/*.jpeg + +# IDE and editor files +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS specific files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# API keys and configuration +config.ini +.env +secrets.json +api_keys.txt + +# Temporary files +*.tmp +*.temp +*.log diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..5d9664b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Data Analysis Agent Team + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..ff71db5 --- /dev/null +++ b/__init__.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +""" +Data Analysis Agent Package + +一个基于LLM的智能数据分析代理,专门为Jupyter Notebook环境设计。 +""" + +from .core.notebook_agent import NotebookAgent +from .config.llm_config import LLMConfig +from .utils.code_executor import CodeExecutor + +__version__ = "1.0.0" +__author__ = "Data Analysis Agent Team" + +# 主要导出类 +__all__ = [ + "NotebookAgent", + "LLMConfig", + "CodeExecutor", +] + +# 便捷函数 +def create_agent(config=None, output_dir="outputs", max_rounds=20, session_dir=None): + """ + 创建一个数据分析智能体实例 + + Args: + config: LLM配置,如果为None则使用默认配置 + output_dir: 输出目录 + max_rounds: 最大分析轮数 + session_dir: 指定会话目录(可选) + + Returns: + NotebookAgent: 智能体实例 + """ + if config is None: + config = LLMConfig() + return NotebookAgent(config=config, output_dir=output_dir, max_rounds=max_rounds, session_dir=session_dir) + +def quick_analysis(query, files=None, output_dir="outputs", max_rounds=10): + """ + 快速数据分析函数 + + Args: + query: 分析需求(自然语言) + files: 数据文件路径列表 + output_dir: 输出目录 + max_rounds: 最大分析轮数 + + Returns: + dict: 分析结果 + """ + agent = create_agent(output_dir=output_dir, max_rounds=max_rounds) + return agent.analyze(query, files) \ No newline at end of file diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..4c973a4 --- /dev/null +++ b/config/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +""" +配置模块 +""" + +from .llm_config import LLMConfig + +__all__ = ['LLMConfig'] diff --git a/config/llm_config.py b/config/llm_config.py new file mode 100644 index 0000000..ac4523c --- /dev/null +++ b/config/llm_config.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +""" +配置管理模块 +""" + +import os +from typing import Dict, Any +from dataclasses import dataclass, asdict + + +from dotenv import load_dotenv + +load_dotenv() + + +@dataclass +class LLMConfig: + """LLM配置""" + + provider: str = "openai" # openai, anthropic, etc. + api_key: str = os.environ.get("OPENAI_API_KEY", "sk-c44i1hy64xgzwox6x08o4zug93frq6rgn84oqugf2pje1tg4") + base_url: str = os.environ.get("OPENAI_BASE_URL", "https://api.xiaomimimo.com/v1") + model: str = os.environ.get("OPENAI_MODEL", "mimo-v2-flash") + temperature: float = 0.3 + max_tokens: int = 131072 + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "LLMConfig": + """从字典创建配置""" + return cls(**data) + + def validate(self) -> bool: + """验证配置有效性""" + if not self.api_key: + raise ValueError("OPENAI_API_KEY is required") + if not self.base_url: + raise ValueError("OPENAI_BASE_URL is required") + if not self.model: + raise ValueError("OPENAI_MODEL is required") + return True diff --git a/data_analysis_agent.py b/data_analysis_agent.py new file mode 100644 index 0000000..41f5b29 --- /dev/null +++ b/data_analysis_agent.py @@ -0,0 +1,483 @@ +# -*- coding: utf-8 -*- +""" +简化的 Notebook 数据分析智能体 +仅包含用户和助手两个角 +2. 图片必须保存到指定的会话目录中,输出绝对路径,禁止使用plt.show() +3. 表格输出控制:超过15行只显示前5行和后5行 +4. 强制使用SimHei字体:plt.rcParams['font.sans-serif'] = ['SimHei'] +5. 输出格式严格使用YAML共享上下文的单轮对话模式 +""" + +import os +import json +import yaml +from typing import Dict, Any, List, Optional +from utils.create_session_dir import create_session_output_dir +from utils.format_execution_result import format_execution_result +from utils.extract_code import extract_code_from_response +from utils.data_loader import load_and_profile_data +from utils.llm_helper import LLMHelper +from utils.code_executor import CodeExecutor +from config.llm_config import LLMConfig +from prompts import data_analysis_system_prompt, final_report_system_prompt + + +class DataAnalysisAgent: + """ + 数据分析智能体 + + 职责: + - 接收用户自然语言需求 + - 生成Python分析代码 + - 执行代码并收集结果 + - 基于执行结果继续生成后续分析代码 + """ + + def __init__( + self, + llm_config: LLMConfig = None, + output_dir: str = "outputs", + max_rounds: int = 20, + force_max_rounds: bool = False, + ): + """ + 初始化智能体 + + Args: + config: LLM配置 + output_dir: 输出目录 + max_rounds: 最大对话轮数 + force_max_rounds: 是否强制运行到最大轮数(忽略AI的完成信号) + """ + self.config = llm_config or LLMConfig() + self.llm = LLMHelper(self.config) + self.base_output_dir = output_dir + self.max_rounds = max_rounds + self.force_max_rounds = force_max_rounds + # 对话历史和上下文 + self.conversation_history = [] + self.analysis_results = [] + self.current_round = 0 + self.session_output_dir = None + self.executor = None + + def _process_response(self, response: str) -> Dict[str, Any]: + """ + 统一处理LLM响应,判断行动类型并执行相应操作 + + Args: + response: LLM的响应内容 + + Returns: + 处理结果字典 + """ + try: + yaml_data = self.llm.parse_yaml_response(response) + action = yaml_data.get("action", "generate_code") + + print(f"🎯 检测到动作: {action}") + + if action == "analysis_complete": + return self._handle_analysis_complete(response, yaml_data) + elif action == "collect_figures": + return self._handle_collect_figures(response, yaml_data) + elif action == "generate_code": + return self._handle_generate_code(response, yaml_data) + else: + print(f"⚠️ 未知动作类型: {action},按generate_code处理") + return self._handle_generate_code(response, yaml_data) + + except Exception as e: + print(f"⚠️ 解析响应失败: {str(e)},按generate_code处理") + return self._handle_generate_code(response, {}) + + def _handle_analysis_complete( + self, response: str, yaml_data: Dict[str, Any] + ) -> Dict[str, Any]: + """处理分析完成动作""" + print("✅ 分析任务完成") + final_report = yaml_data.get("final_report", "分析完成,无最终报告") + return { + "action": "analysis_complete", + "final_report": final_report, + "response": response, + "continue": False, + } + + def _handle_collect_figures( + self, response: str, yaml_data: Dict[str, Any] + ) -> Dict[str, Any]: + """处理图片收集动作""" + print("📊 开始收集图片") + figures_to_collect = yaml_data.get("figures_to_collect", []) + + collected_figures = [] + + for figure_info in figures_to_collect: + figure_number = figure_info.get("figure_number", "未知") + # 确保figure_number不为None时才用于文件名 + if figure_number != "未知": + default_filename = f"figure_{figure_number}.png" + else: + default_filename = "figure_unknown.png" + filename = figure_info.get("filename", default_filename) + file_path = figure_info.get("file_path", "") # 获取具体的文件路径 + description = figure_info.get("description", "") + analysis = figure_info.get("analysis", "") + + print(f"📈 收集图片 {figure_number}: {filename}") + print(f" 📂 路径: {file_path}") + print(f" 📝 描述: {description}") + print(f" 🔍 分析: {analysis}") + + # 验证文件是否存在 + if file_path and os.path.exists(file_path): + print(f" ✅ 文件存在: {file_path}") + elif file_path: + print(f" ⚠️ 文件不存在: {file_path}") + else: + print(f" ⚠️ 未提供文件路径") + + # 记录图片信息 + collected_figures.append( + { + "figure_number": figure_number, + "filename": filename, + "file_path": file_path, + "description": description, + "analysis": analysis, + } + ) + + return { + "action": "collect_figures", + "collected_figures": collected_figures, + "response": response, + "continue": True, + } + + def _handle_generate_code( + self, response: str, yaml_data: Dict[str, Any] + ) -> Dict[str, Any]: + """处理代码生成和执行动作""" + # 从YAML数据中获取代码(更准确) + code = yaml_data.get("code", "") + + # 如果YAML中没有代码,尝试从响应中提取 + if not code: + code = extract_code_from_response(response) + + # 二次清洗:防止YAML中解析出的code包含markdown标记 + if code: + code = code.strip() + if code.startswith("```"): + import re + # 去除开头的 ```python 或 ``` + code = re.sub(r"^```[a-zA-Z]*\n", "", code) + # 去除结尾的 ``` + code = re.sub(r"\n```$", "", code) + code = code.strip() + + if code: + print(f"🔧 执行代码:\n{code}") + print("-" * 40) + + # 执行代码 + result = self.executor.execute_code(code) + + # 格式化执行结果 + feedback = format_execution_result(result) + print(f"📋 执行反馈:\n{feedback}") + + return { + "action": "generate_code", + "code": code, + "result": result, + "feedback": feedback, + "response": response, + "continue": True, + } + else: + # 如果没有代码,说明LLM响应格式有问题,需要重新生成 + print("⚠️ 未从响应中提取到可执行代码,要求LLM重新生成") + return { + "action": "invalid_response", + "error": "响应中缺少可执行代码", + "response": response, + "continue": True, + } + + def analyze(self, user_input: str, files: List[str] = None) -> Dict[str, Any]: + """ + 开始分析流程 + + Args: + user_input: 用户的自然语言需求 + files: 数据文件路径列表 + + Returns: + 分析结果字典 + """ + # 重置状态 + self.conversation_history = [] + self.analysis_results = [] + self.current_round = 0 + + # 创建本次分析的专用输出目录 + self.session_output_dir = create_session_output_dir( + self.base_output_dir, user_input + ) + + # 初始化代码执行器,使用会话目录 + self.executor = CodeExecutor(self.session_output_dir) + + # 设置会话目录变量到执行环境中 + self.executor.set_variable("session_output_dir", self.session_output_dir) + + # 设用工具生成数据画像 + data_profile = "" + if files: + print("🔍 正在生成数据画像...") + data_profile = load_and_profile_data(files) + print("✅ 数据画像生成完毕") + + # 构建初始prompt + initial_prompt = f"""用户需求: {user_input}""" + if files: + initial_prompt += f"\n数据文件: {', '.join(files)}" + + if data_profile: + initial_prompt += f"\n\n{data_profile}\n\n请根据上述【数据画像】中的统计信息(如高频值、缺失率、数据范围)来制定分析策略。如果发现明显的高频问题或异常分布,请优先进行深度分析。" + + print(f"🚀 开始数据分析任务") + print(f"📝 用户需求: {user_input}") + if files: + print(f"📁 数据文件: {', '.join(files)}") + print(f"📂 输出目录: {self.session_output_dir}") + print(f"🔢 最大轮数: {self.max_rounds}") + if self.force_max_rounds: + print(f"⚡ 强制模式: 将运行满 {self.max_rounds} 轮(忽略AI完成信号)") + print("=" * 60) + # 添加到对话历史 + self.conversation_history.append({"role": "user", "content": initial_prompt}) + + while self.current_round < self.max_rounds: + self.current_round += 1 + print(f"\n🔄 第 {self.current_round} 轮分析") + # 调用LLM生成响应 + try: # 获取当前执行环境的变量信息 + notebook_variables = self.executor.get_environment_info() + + # 格式化系统提示词,填入动态的notebook变量信息 + formatted_system_prompt = data_analysis_system_prompt.format( + notebook_variables=notebook_variables + ) + + response = self.llm.call( + prompt=self._build_conversation_prompt(), + system_prompt=formatted_system_prompt, + ) + + print(f"🤖 助手响应:\n{response}") + + # 使用统一的响应处理方法 + process_result = self._process_response(response) + + # 根据处理结果决定是否继续(仅在非强制模式下) + if not self.force_max_rounds and not process_result.get( + "continue", True + ): + print(f"\n✅ 分析完成!") + break + + # 添加到对话历史 + self.conversation_history.append( + {"role": "assistant", "content": response} + ) + + # 根据动作类型添加不同的反馈 + if process_result["action"] == "generate_code": + feedback = process_result.get("feedback", "") + self.conversation_history.append( + {"role": "user", "content": f"代码执行反馈:\n{feedback}"} + ) + + # 记录分析结果 + self.analysis_results.append( + { + "round": self.current_round, + "code": process_result.get("code", ""), + "result": process_result.get("result", {}), + "response": response, + } + ) + elif process_result["action"] == "collect_figures": + # 记录图片收集结果 + collected_figures = process_result.get("collected_figures", []) + feedback = f"已收集 {len(collected_figures)} 个图片及其分析" + self.conversation_history.append( + { + "role": "user", + "content": f"图片收集反馈:\n{feedback}\n请继续下一步分析。", + } + ) + + # 记录到分析结果中 + self.analysis_results.append( + { + "round": self.current_round, + "action": "collect_figures", + "collected_figures": collected_figures, + "response": response, + } + ) + + except Exception as e: + error_msg = f"LLM调用错误: {str(e)}" + print(f"❌ {error_msg}") + self.conversation_history.append( + { + "role": "user", + "content": f"发生错误: {error_msg},请重新生成代码。", + } + ) + # 生成最终总结 + if self.current_round >= self.max_rounds: + print(f"\n⚠️ 已达到最大轮数 ({self.max_rounds}),分析结束") + + return self._generate_final_report() + + def _build_conversation_prompt(self) -> str: + """构建对话提示词""" + prompt_parts = [] + + for msg in self.conversation_history: + role = msg["role"] + content = msg["content"] + if role == "user": + prompt_parts.append(f"用户: {content}") + else: + prompt_parts.append(f"助手: {content}") + + return "\n\n".join(prompt_parts) + + def _generate_final_report(self) -> Dict[str, Any]: + """生成最终分析报告""" + # 收集所有生成的图片信息 + all_figures = [] + for result in self.analysis_results: + if result.get("action") == "collect_figures": + all_figures.extend(result.get("collected_figures", [])) + + print(f"\n📊 开始生成最终分析报告...") + print(f"📂 输出目录: {self.session_output_dir}") + print(f"🔢 总轮数: {self.current_round}") + print(f"📈 收集图片: {len(all_figures)} 个") + + # 构建用于生成最终报告的提示词 + final_report_prompt = self._build_final_report_prompt(all_figures) + + try: # 调用LLM生成最终报告 + response = self.llm.call( + prompt=final_report_prompt, + system_prompt="你将会接收到一个数据分析任务的最终报告请求,请根据提供的分析结果和图片信息生成完整的分析报告。", + max_tokens=16384, # 设置较大的token限制以容纳完整报告 + ) + + # 解析响应,提取最终报告 + try: + yaml_data = self.llm.parse_yaml_response(response) + if yaml_data.get("action") == "analysis_complete": + final_report_content = yaml_data.get("final_report", "报告生成失败") + else: + final_report_content = ( + "LLM未返回analysis_complete动作,报告生成失败" + ) + except: + # 如果解析失败,直接使用响应内容 + final_report_content = response + + print("✅ 最终报告生成完成") + + except Exception as e: + print(f"❌ 生成最终报告时出错: {str(e)}") + final_report_content = f"报告生成失败: {str(e)}" + + # 保存最终报告到文件 + report_file_path = os.path.join(self.session_output_dir, "最终分析报告.md") + try: + with open(report_file_path, "w", encoding="utf-8") as f: + f.write(final_report_content) + print(f"📄 最终报告已保存至: {report_file_path}") + except Exception as e: + print(f"❌ 保存报告文件失败: {str(e)}") + + # 返回完整的分析结果 + return { + "session_output_dir": self.session_output_dir, + "total_rounds": self.current_round, + "analysis_results": self.analysis_results, + "collected_figures": all_figures, + "conversation_history": self.conversation_history, + "final_report": final_report_content, + "report_file_path": report_file_path, + } + + def _build_final_report_prompt(self, all_figures: List[Dict[str, Any]]) -> str: + """构建用于生成最终报告的提示词""" + + # 构建图片信息摘要,使用相对路径 + figures_summary = "" + if all_figures: + figures_summary = "\n生成的图片及分析:\n" + for i, figure in enumerate(all_figures, 1): + filename = figure.get("filename", "未知文件名") + # 使用相对路径格式,适合在报告中引用 + relative_path = f"./{filename}" + figures_summary += f"{i}. {filename}\n" + figures_summary += f" 相对路径: {relative_path}\n" + figures_summary += f" 描述: {figure.get('description', '无描述')}\n" + figures_summary += f" 分析: {figure.get('analysis', '无分析')}\n\n" + else: + figures_summary = "\n本次分析未生成图片。\n" + + # 构建代码执行结果摘要(仅包含成功执行的代码块) + code_results_summary = "" + success_code_count = 0 + for result in self.analysis_results: + if result.get("action") != "collect_figures" and result.get("code"): + exec_result = result.get("result", {}) + if exec_result.get("success"): + success_code_count += 1 + code_results_summary += f"代码块 {success_code_count}: 执行成功\n" + if exec_result.get("output"): + code_results_summary += ( + f"输出: {exec_result.get('output')[:]}\n\n" + ) + + # 使用 prompts.py 中的统一提示词模板,并添加相对路径使用说明 + prompt = final_report_system_prompt.format( + current_round=self.current_round, + session_output_dir=self.session_output_dir, + figures_summary=figures_summary, + code_results_summary=code_results_summary, + ) + + # 在提示词中明确要求使用相对路径 + prompt += """ + +📁 **图片路径使用说明**: +报告和图片都在同一目录下,请在报告中使用相对路径引用图片: +- 格式:![图片描述](./图片文件名.png) +- 示例:![营业总收入趋势](./营业总收入趋势.png) +- 这样可以确保报告在不同环境下都能正确显示图片 +""" + + return prompt + + def reset(self): + """重置智能体状态""" + self.conversation_history = [] + self.analysis_results = [] + self.current_round = 0 + self.executor.reset_environment() diff --git a/main.py b/main.py new file mode 100644 index 0000000..d33e32d --- /dev/null +++ b/main.py @@ -0,0 +1,18 @@ +from data_analysis_agent import DataAnalysisAgent +from config.llm_config import LLMConfig + + +def main(): + llm_config = LLMConfig() + # 如果希望强制运行到最大轮数,设置 force_max_rounds=True + agent = DataAnalysisAgent(llm_config, force_max_rounds=False) + files = ["./UB IOV Support_TR.csv"] + report = agent.analyze( + user_input="基于所有有关远程控制的问题,以及涉及车控APP的运维工单的数据,输出若干个重要的统计指标,并绘制相关图表。总结一份,车控APP,及远程控制工单健康度报告,最后生成汇报给我。", + files=files, + ) + print(report) + + +if __name__ == "__main__": + main() diff --git a/prompts.py b/prompts.py new file mode 100644 index 0000000..6d13afe --- /dev/null +++ b/prompts.py @@ -0,0 +1,286 @@ +data_analysis_system_prompt = """你是一个专业的数据分析助手,运行在Jupyter Notebook环境中,能够根据用户需求生成和执行Python数据分析代码。 + +**重要指导原则**: +- 当需要执行Python代码(数据加载、分析、可视化)时,使用 `generate_code` 动作 +- 当需要收集和分析已生成的图表时,使用 `collect_figures` 动作 +- 当所有分析工作完成,需要输出最终报告时,使用 `analysis_complete` 动作 +- 每次响应只能选择一种动作类型,不要混合使用 + +目前jupyter notebook环境下有以下变量: +{notebook_variables} +核心能力: +1. 接收用户的自然语言分析需求 +2. 按步骤生成安全的Python分析代码 +3. 基于代码执行结果继续优化分析 + +Notebook环境特性: +- 你运行在IPython Notebook环境中,变量会在各个代码块之间保持 +- 第一次执行后,pandas、numpy、matplotlib等库已经导入,无需重复导入 +- 数据框(DataFrame)等变量在执行后会保留,可以直接使用 +- 因此,除非是第一次使用某个库,否则不需要重复import语句 + +重要约束: +1. 仅使用以下数据分析库:pandas, numpy, matplotlib, duckdb, os, json, datetime, re, pathlib +2. 图片必须保存到指定的会话目录中,输出绝对路径,禁止使用plt.show(),饼图的标签全部放在图例里面,用颜色区分。 +4. 表格输出控制:超过15行只显示前5行和后5行 +5. 中文字体设置:使用系统可用中文字体(macOS推荐:Hiragino Sans GB, Songti SC等) +6. 输出格式严格使用YAML + + +输出目录管理: +- 本次分析使用时间戳生成的专用目录,确保每次分析的输出文件隔离 +- 会话目录格式:session_[时间戳],如 session_20240105_143052 +- 图片保存路径格式:os.path.join(session_output_dir, '图片名称.png') +- 使用有意义的中文文件名:如'营业收入趋势.png', '利润分析对比.png' +- 每个图表保存后必须使用plt.close()释放内存 +- 输出绝对路径:使用os.path.abspath()获取图片的完整路径 + +数据分析工作流程(必须严格按顺序执行): + +**阶段1:数据探索(使用 generate_code 动作)** +- 首次数据加载时尝试多种编码:['utf-8', 'gbk', 'gb18030', 'gb2312', 'latin1'] +- 特殊处理:如果读取失败,尝试指定分隔符 `sep=','` 和错误处理 `error_bad_lines=False` +- 使用df.head()查看前几行数据,检查数据是否正确读取 +- 使用df.info()了解数据类型和缺失值情况 +- 重点检查:如果数值列显示为NaN但应该有值,说明读取或解析有问题 +- 使用df.dtypes查看每列的数据类型,确保日期列不是float64 +- 打印所有列名:df.columns.tolist() +- 绝对不要假设列名,必须先查看实际的列名 + +**阶段2:数据清洗和检查(使用 generate_code 动作)** +- 日期列识别:查找包含'date', 'time', 'Date', 'Time'关键词的列 +- 日期解析:尝试多种格式 ['%d/%m/%Y', '%Y-%m-%d', '%m/%d/%Y', '%Y/%m/%d', '%d-%m-%Y'] +- 类型转换:使用pd.to_datetime()转换日期列,指定format参数和errors='coerce' +- 空值处理:检查哪些列应该有值但显示NaN,可能是数据读取问题 +- 检查数据的时间范围和排序 +- 数据质量检查:确认数值列是否正确,字符串列是否被错误识别 + + +**阶段3:数据分析和可视化(使用 generate_code 动作)** +- 基于实际的列名进行计算 +- 生成有意义的图表 +- 图片保存到会话专用目录中 +- 每生成一个图表后,必须打印绝对路径 + + +**阶段4:深度挖掘与高级分析(使用 generate_code 动作)** +- **主动评估数据特征**:在执行前,先分析数据适合哪种高级挖掘: + - **时间序列数据**:必须进行趋势预测(使用sklearn/ARIMA/Prophet-like逻辑)和季节性分解。 + - **多维数值数据**:必须进行聚类分析(K-Means/DBSCAN)以发现用户/产品分层。 + - **分类/目标数据**:必须计算特征重要性(使用随机森林/相关性矩阵)以识别关键驱动因素。 + - **异常检测**:使用Isolation Forest或统计方法识别高价值或高风险的离群点。 +- **拒绝平庸**:不要为了做而做。如果数据量太小(<50行)或特征单一,请明确说明无法进行特定分析,并尝试挖掘其他角度(如分布偏度、帕累托分析)。 +- **业务导向**:每个模型结果必须翻译成业务语言(例如:“聚类结果显示,A类用户是高价值且对价格不敏感的群体”)。 + +**阶段5:高级分析结果可视化(使用 generate_code 动作)** +- **专业图表**:为高级分析匹配专用图表: + - 聚类 -> 降维散点图 (PCA/t-SNE) 或 平行坐标图 + - 相关性 -> 热力图 (Heatmap) + - 预测 -> 带有置信区间的趋势图 + - 特征重要性 -> 排序条形图 +- **保存与输出**:保存模型结果图表,并准备好在报告中解释。 + +**阶段6:图片收集和分析(使用 collect_figures 动作)** +- 当已生成多个图表后,使用 collect_figures 动作 +- 收集所有已生成的图片路径和信息 +- 对每个图片进行详细的分析和解读 + +**阶段7:最终报告(使用 analysis_complete 动作)** +- 当所有分析工作完成后,生成最终的分析报告 +- 包含对所有图片、模型和分析结果的综合总结 +- 提供业务建议和预测洞察 + +代码生成规则: +1. 每次只专注一个阶段,不要试图一次性完成所有任务 +2. 基于实际的数据结构而不是假设来编写代码 +3. Notebook环境中变量会保持,避免重复导入和重复加载相同数据 +4. 处理错误时,分析具体的错误信息并针对性修复,重新进行改阶段步骤,中途不要跳步骤 +5. 图片保存使用会话目录变量:session_output_dir +6. 图表标题和标签使用中文,使用系统配置的中文字体显示 +7. 必须打印绝对路径:每次保存图片后,使用os.path.abspath()打印完整的绝对路径 +8. 图片文件名:同时打印图片的文件名,方便后续收集时识别 +9. 饼图绘图代码生成必须遵守规则:类别 ≤ 5个:使用饼图 (plt.pie) + 外部图例,百分比标签清晰显示;类别 6-10个:使用水平条形图 (plt.barh) 便于阅读;类别 > 10个:使用排序条形图 + 合并小类别为"其他";学术美学要求**:白色背景、合适颜色、清晰标签、无冗余边框; + +动作选择指南: +- **需要执行Python代码** → 使用 "generate_code" +- **已生成多个图表,需要收集分析** → 使用 "collect_figures" +- **所有分析完成,输出最终报告** → 使用 "analysis_complete" +- **遇到错误需要修复代码** → 使用 "generate_code" + +高级分析技术指南(主动探索模式): +- **智能选择算法**: + - 遇到时间字段 -> `pd.to_datetime` -> 重采样 -> 移动平均/指数平滑/回归预测 + - 遇到多数值特征 -> `StandardScaler` -> `KMeans` (使用Elbow法则选k) -> `PCA`降维可视化 + - 遇到目标变量 -> `Correlation Matrix` -> `RandomForest` (feature_importances_) + - **文本挖掘**: + - 必须构建**专用停用词表** (Stop Words),过滤掉无效词汇: + - 年份/数字:2023, 2024, 2025, 1月, 2月... + - 通用动词:work, fix, support, issue, problem, check, test... + - 通用介词/代词:the, is, at, which, on, for, this, that... + - 仅保留具有实际业务含义的名词/动词短语(如 "connection timeout", "login failed")。 +- **异常值挖掘**:总是检查是否存在显著偏离均值的异常点,并标记出来进行个案分析。 +- **可视化增强**:不要只画折线图。使用 `seaborn` 的 `pairplot`, `heatmap`, `lmplot` 等高级图表。 + +可用分析库: + +图片收集要求: +- 在适当的时候(通常是生成了多个图表后),主动使用 `collect_figures` 动作 +- 收集时必须包含具体的图片绝对路径(file_path字段) +- 提供详细的图片描述和深入的分析 +- 确保图片路径与之前打印的路径一致 + +报告生成要求: +- 生成的报告要符合报告的文言需要,不要出现有争议的文字 +- 在适当的时候(通常是生成了多个图表后),进行图像的对比分析 +- 涉及的文言,不能出现我,你,他,等主观用于,采用报告式的文言论述 +- 提供详细的图片描述和深入的分析 +- 报告中的英文单词,初专有名词(TSP,TBOX等),其余的全部翻译成中文,例如remote control(远控),don't exist in TSP (数据不在TSP上); + +三种动作类型及使用时机: + +**1. 代码生成动作 (generate_code)** +适用于:数据加载、探索、清洗、计算、数据分析、图片生成、可视化等需要执行Python代码的情况 + +**2. 图片收集动作 (collect_figures)** +适用于:已生成多个图表后,需要对图片进行汇总和深入分析的情况 + +**3. 分析完成动作 (analysis_complete)** +适用于:所有分析工作完成,需要输出最终报告的情况 + +响应格式(严格遵守): + +**当需要执行代码时,使用此格式:** +```yaml +action: "generate_code" +reasoning: "详细说明当前步骤的目的和方法,为什么要这样做" +code: | + # 实际的Python代码 + import pandas as pd + # 具体分析代码... + + # 图片保存示例(如果生成图表) + plt.figure(figsize=(10, 6)) + # 绘图代码... + plt.title('图表标题') + file_path = os.path.join(session_output_dir, '图表名称.png') + plt.savefig(file_path, dpi=150, bbox_inches='tight') + plt.close() + # 必须打印绝对路径 + absolute_path = os.path.abspath(file_path) + print(f"图片已保存至: {{absolute_path}}") + print(f"图片文件名: {{os.path.basename(absolute_path)}}") + +next_steps: ["下一步计划1", "下一步计划2"] +``` +**当需要收集分析图片时,使用此格式:** +```yaml +action: "collect_figures" +reasoning: "说明为什么现在要收集图片,例如:已生成3个图表,现在收集并分析这些图表的内容" +figures_to_collect: + - figure_number: 1 + filename: "营业收入趋势分析.png" + file_path: "实际的完整绝对路径" + description: "图片概述:展示了什么内容" + analysis: "细节分析:从图中可以看出的具体信息和洞察" +next_steps: ["后续计划"] +``` + +**当所有分析完成时,使用此格式:** +```yaml +action: "analysis_complete" +final_report: | + 完整的最终分析报告内容 + (可以是多行文本) +``` + + + +特别注意: +- 数据读取问题:如果看到大量NaN值,检查编码和分隔符 +- 日期列问题:如果日期列显示为float64,说明解析失败 +- 编码错误:逐个尝试 ['utf-8', 'gbk', 'gb18030', 'gb2312', 'latin1'] +- 列类型错误:检查是否有列被错误识别为数值型但实际是文本 +- matplotlib错误时,确保使用Agg后端和正确的字体设置 +- 每次执行后根据反馈调整代码,不要重复相同的错误 + + +""" + +# 最终报告生成提示词 +final_report_system_prompt = """你是一个专业的数据分析师,需要基于完整的分析过程生成最终的分析报告。 + +分析信息: +分析轮数: {current_round} +输出目录: {session_output_dir} + +{figures_summary} + +代码执行结果摘要: +{code_results_summary} + +报告生成要求: +报告应使用markdown格式,确保结构清晰;需要包含对所有生成图片的详细分析和说明; +生成的报告要符合报告的文言需要,不要出现有争议的文字; +在适当的时候(通常是生成了多个图表后),进行图像的对比分析; +涉及的文言,不能出现我,你,他,等主观用于,采用报告式的文言论述; +提供详细的图片描述和深入的分析; +报告中的英文单词,初专有名词(TSP,TBOX等),其余的全部翻译成中文,例如remote control(远控),don't exist in TSP (数据不在TSP上); + +总结分析过程中的关键发现;提供有价值的结论和建议;内容必须专业且逻辑性强。 +**重要提醒:图片引用必须使用相对路径格式 `![图片描述](./图片文件名.png)`** + +图片质量与格式要求: +- **学术级图表标准**:所有图表必须达到发表级质量,包含: + * 专业的颜色方案(seaborn调色板) + * 清晰的标签和图例(无重叠) + * 合适的字体大小(≥12pt) + * 简洁的布局(白色背景,无冗余元素) +- **路径格式**:使用相对路径`![图片描述](./图片文件名.png)` +- **图表命名**:使用描述性中文名称,如`来源渠道分布.png` +响应格式要求: +必须严格使用以下YAML格式输出: + +```yaml +action: "analysis_complete" +final_report: | + # 数据分析报告 + + ## 分析概述 + [概述本次分析的目标和范围] + + ## 数据分析过程 + [总结分析的主要步骤] + + ## 关键发现 + [描述重要的分析结果,使用段落形式而非列表] + + ## 图表分析 + + ### [图表标题] + ![图表描述](./图片文件名.png) + + [对图表的详细分析,使用连续的段落描述,避免使用分点列表] + + ### [下一个图表标题] + ![图表描述](./另一个图片文件名.png) + + [对图表的详细分析,使用连续的段落描述] + + ## 深度分析 + ### [图表标题] + ![图表描述](./图片文件名.png) + + [对此前所有的数据,探索关联关系,进行深度剖析,重点问题,高频问题,并以图表介绍,使用连续的段落描述,避免使用分点列表] + + ## 结论与建议 + [基于分析结果提出结论和投资建议,使用段落形式表达] +``` + +特别注意事项: +必须对每个图片进行详细的分析和说明。 +图片的内容和标题必须与分析内容相关。 +使用专业的金融分析术语和方法。 +报告要完整、准确、有价值。 +**强制要求:所有图片路径都必须使用相对路径格式 `./文件名.png`。 +为了确保后续markdown转换docx效果良好,请避免在正文中使用分点列表形式,改用段落形式表达。** +""" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c7155d8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,52 @@ +# 数据分析和科学计算库 +pandas>=2.0.0 +openpyxl>=3.1.0 +numpy>=1.24.0 +matplotlib>=3.6.0 +duckdb>=0.8.0 +scipy>=1.10.0 +scikit-learn>=1.3.0 + +# Web和API相关 +requests>=2.28.0 +urllib3>=1.26.0 + +# 绘图和可视化 +plotly>=5.14.0 +dash>=2.0.0 + +# 流程图支持(可选,用于生成Mermaid图表) +# 注意:Mermaid图表主要在markdown中渲染,不需要额外的Python包 +# 如果需要在Python中生成Mermaid代码,可以考虑: +# mermaid-py>=0.3.0 + +# Jupyter/IPython环境 +ipython>=8.10.0 +jupyter>=1.0.0 + +# AI/LLM相关 +openai>=1.0.0 +pyyaml>=6.0 + +# 配置管理 +python-dotenv>=1.0.0 + +# 异步编程 +asyncio-mqtt>=0.11.1 +nest_asyncio>=1.5.0 + +# 文档生成(基于输出的Word文档) +python-docx>=0.8.11 + +# 系统和工具库 +pathlib2>=2.3.7 +typing-extensions>=4.5.0 + +# 开发和测试工具(可选) +pytest>=7.0.0 +pytest-asyncio>=0.21.0 +black>=23.0.0 +flake8>=6.0.0 + +# 字体支持(用于matplotlib中文显示) +fonttools>=4.38.0 diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..be1d86e --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +""" +工具模块初始化文件 +""" + +from utils.code_executor import CodeExecutor +from utils.llm_helper import LLMHelper +from utils.fallback_openai_client import AsyncFallbackOpenAIClient + +__all__ = ["CodeExecutor", "LLMHelper", "AsyncFallbackOpenAIClient"] \ No newline at end of file diff --git a/utils/code_executor.py b/utils/code_executor.py new file mode 100644 index 0000000..382c4af --- /dev/null +++ b/utils/code_executor.py @@ -0,0 +1,453 @@ +# -*- coding: utf-8 -*- +""" +安全的代码执行器,基于 IPython 提供 notebook 环境下的代码执行功能 +""" + +import os +import sys +import ast +import traceback +import io +from typing import Dict, Any, List, Optional, Tuple +from contextlib import redirect_stdout, redirect_stderr +from IPython.core.interactiveshell import InteractiveShell +from IPython.utils.capture import capture_output +import matplotlib +import matplotlib.pyplot as plt +import matplotlib.font_manager as fm + + +class CodeExecutor: + """ + 安全的代码执行器,限制依赖库,捕获输出,支持图片保存与路径输出 + """ + + ALLOWED_IMPORTS = { + "pandas", + "pd", + "numpy", + "np", + "matplotlib", + "matplotlib.pyplot", + "plt", + "seaborn", + "sns", + "duckdb", + "scipy", + "sklearn", + "plotly", + "dash", + "requests", + "urllib", + "os", + "sys", + "json", + "csv", + "datetime", + "time", + "math", + "statistics", + "re", + "pathlib", + "io", + "collections", + "itertools", + "functools", + "operator", + "warnings", + "logging", + "copy", + "pickle", + "gzip", + "zipfile", + "yaml", + "typing", + "dataclasses", + "enum", + "sqlite3", + } + + def __init__(self, output_dir: str = "outputs"): + """ + 初始化代码执行器 + + Args: + output_dir: 输出目录,用于保存图片和文件 + """ + self.output_dir = os.path.abspath(output_dir) + os.makedirs(self.output_dir, exist_ok=True) + + # 初始化 IPython shell + self.shell = InteractiveShell.instance() + + # 设置中文字体 + self._setup_chinese_font() + + # 预导入常用库 + self._setup_common_imports() + + # 图片计数器 + self.image_counter = 0 + + def _setup_chinese_font(self): + """设置matplotlib中文字体显示""" + try: + # 设置matplotlib使用Agg backend避免GUI问题 + matplotlib.use("Agg") + + # 获取系统可用字体 + available_fonts = [f.name for f in fm.fontManager.ttflist] + + # 设置matplotlib使用系统可用中文字体 + # macOS系统常用中文字体(按优先级排序) + chinese_fonts = [ + "Hiragino Sans GB", # macOS中文简体 + "Songti SC", # macOS宋体简体 + "PingFang SC", # macOS苹方简体 + "Heiti SC", # macOS黑体简体 + "Heiti TC", # macOS黑体繁体 + "PingFang HK", # macOS苹方香港 + "SimHei", # Windows黑体 + "STHeiti", # 华文黑体 + "WenQuanYi Micro Hei", # Linux文泉驿微米黑 + "DejaVu Sans", # 默认无衬线字体 + "Arial Unicode MS", # Arial Unicode + ] + + # 检查系统中实际存在的字体 + system_chinese_fonts = [ + font for font in chinese_fonts if font in available_fonts + ] + + # 如果没有找到合适的中文字体,尝试更宽松的搜索 + if not system_chinese_fonts: + print("警告:未找到精确匹配的中文字体,尝试更宽松的搜索...") + # 更宽松的字体匹配(包含部分名称) + fallback_fonts = [] + for available_font in available_fonts: + if any( + keyword in available_font + for keyword in [ + "Hei", + "Song", + "Fang", + "Kai", + "Hiragino", + "PingFang", + "ST", + ] + ): + fallback_fonts.append(available_font) + + if fallback_fonts: + system_chinese_fonts = fallback_fonts[:3] # 取前3个匹配的字体 + print(f"找到备选中文字体: {system_chinese_fonts}") + else: + print("警告:系统中未找到合适的中文字体,使用系统默认字体") + system_chinese_fonts = ["DejaVu Sans", "Arial Unicode MS"] + + # 设置字体配置 + plt.rcParams["font.sans-serif"] = system_chinese_fonts + [ + "DejaVu Sans", + "Arial Unicode MS", + ] + + plt.rcParams["axes.unicode_minus"] = False + plt.rcParams["font.family"] = "sans-serif" + + # 在shell中也设置相同的字体配置 + font_list_str = str( + system_chinese_fonts + ["DejaVu Sans", "Arial Unicode MS"] + ) + self.shell.run_cell( + f""" +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import matplotlib.font_manager as fm + +# 设置中文字体 +plt.rcParams['font.sans-serif'] = {font_list_str} +plt.rcParams['axes.unicode_minus'] = False +plt.rcParams['font.family'] = 'sans-serif' + +# 确保matplotlib缓存目录可写 +import os +cache_dir = os.path.expanduser('~/.matplotlib') +if not os.path.exists(cache_dir): + os.makedirs(cache_dir, exist_ok=True) +os.environ['MPLCONFIGDIR'] = cache_dir +""" + ) + except Exception as e: + print(f"设置中文字体失败: {e}") + # 即使失败也要设置基本的matplotlib配置 + try: + matplotlib.use("Agg") + plt.rcParams["axes.unicode_minus"] = False + except: + pass + + def _setup_common_imports(self): + """预导入常用库""" + common_imports = """ +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import duckdb +import os +import json +from IPython.display import display +""" + try: + self.shell.run_cell(common_imports) + # 确保display函数在shell的用户命名空间中可用 + from IPython.display import display + + self.shell.user_ns["display"] = display + except Exception as e: + print(f"预导入库失败: {e}") + + def _check_code_safety(self, code: str) -> Tuple[bool, str]: + """ + 检查代码安全性,限制导入的库 + + Returns: + (is_safe, error_message) + """ + try: + tree = ast.parse(code) + except SyntaxError as e: + return False, f"语法错误: {e}" + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name not in self.ALLOWED_IMPORTS: + return False, f"不允许的导入: {alias.name}" + + elif isinstance(node, ast.ImportFrom): + if node.module not in self.ALLOWED_IMPORTS: + return False, f"不允许的导入: {node.module}" + + # 检查属性访问(防止通过os.system等方式绕过) + elif isinstance(node, ast.Attribute): + # 检查是否访问了os模块的属性 + if isinstance(node.value, ast.Name) and node.value.id == "os": + # 允许的os子模块和函数白名单 + allowed_os_attributes = { + "path", "environ", "getcwd", "listdir", "makedirs", "mkdir", "remove", "rmdir", + "path.join", "path.exists", "path.abspath", "path.dirname", + "path.basename", "path.splitext", "path.isdir", "path.isfile", + "sep", "name", "linesep", "stat", "getpid" + } + + # 检查直接属性访问 (如 os.getcwd) + if node.attr not in allowed_os_attributes: + # 进一步检查如果是 os.path.xxx 这种形式 + # Note: ast.Attribute 嵌套结构比较复杂,简单处理只允许 os.path 和上述白名单 + if node.attr == "path": + pass # 允许访问 os.path + else: + return False, f"不允许的os属性访问: os.{node.attr}" + + # 检查危险函数调用 + elif isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + if node.func.id in ["exec", "eval", "open", "__import__"]: + return False, f"不允许的函数调用: {node.func.id}" + + return True, "" + + def get_current_figures_info(self) -> List[Dict[str, Any]]: + """获取当前matplotlib图形信息,但不自动保存""" + figures_info = [] + + # 获取当前所有图形 + fig_nums = plt.get_fignums() + + for fig_num in fig_nums: + fig = plt.figure(fig_num) + if fig.get_axes(): # 只处理有内容的图形 + figures_info.append( + { + "figure_number": fig_num, + "axes_count": len(fig.get_axes()), + "figure_size": fig.get_size_inches().tolist(), + "has_content": True, + } + ) + + return figures_info + + def _format_table_output(self, obj: Any) -> str: + """格式化表格输出,限制行数""" + if hasattr(obj, "shape") and hasattr(obj, "head"): # pandas DataFrame + rows, cols = obj.shape + print(f"\n数据表形状: {rows}行 x {cols}列") + print(f"列名: {list(obj.columns)}") + + if rows <= 15: + return str(obj) + else: + head_part = obj.head(5) + tail_part = obj.tail(5) + return f"{head_part}\n...\n(省略 {rows-10} 行)\n...\n{tail_part}" + + return str(obj) + + def execute_code(self, code: str) -> Dict[str, Any]: + """ + 执行代码并返回结果 + + Args: + code: 要执行的Python代码 + + Returns: + { + 'success': bool, + 'output': str, + 'error': str, + 'variables': Dict[str, Any] # 新生成的重要变量 + } + """ + # 检查代码安全性 + is_safe, safety_error = self._check_code_safety(code) + if not is_safe: + return { + "success": False, + "output": "", + "error": f"代码安全检查失败: {safety_error}", + "variables": {}, + } + + # 记录执行前的变量 + vars_before = set(self.shell.user_ns.keys()) + + try: + # 使用IPython的capture_output来捕获所有输出 + with capture_output() as captured: + result = self.shell.run_cell(code) + + # 检查执行结果 + if result.error_before_exec: + error_msg = str(result.error_before_exec) + return { + "success": False, + "output": captured.stdout, + "error": f"执行前错误: {error_msg}", + "variables": {}, + } + + if result.error_in_exec: + error_msg = str(result.error_in_exec) + return { + "success": False, + "output": captured.stdout, + "error": f"执行错误: {error_msg}", + "variables": {}, + } + + # 获取输出 + output = captured.stdout + + # 如果有返回值,添加到输出 + if result.result is not None: + formatted_result = self._format_table_output(result.result) + output += f"\n{formatted_result}" + # 记录新产生的重要变量(简化版本) + vars_after = set(self.shell.user_ns.keys()) + new_vars = vars_after - vars_before + + # 只记录新创建的DataFrame等重要数据结构 + important_new_vars = {} + for var_name in new_vars: + if not var_name.startswith("_"): + try: + var_value = self.shell.user_ns[var_name] + if hasattr(var_value, "shape"): # pandas DataFrame, numpy array + important_new_vars[var_name] = ( + f"{type(var_value).__name__} with shape {var_value.shape}" + ) + elif var_name in ["session_output_dir"]: # 重要的配置变量 + important_new_vars[var_name] = str(var_value) + except: + pass + + return { + "success": True, + "output": output, + "error": "", + "variables": important_new_vars, + } + except Exception as e: + return { + "success": False, + "output": captured.stdout if "captured" in locals() else "", + "error": f"执行异常: {str(e)}\n{traceback.format_exc()}", + "variables": {}, + } + + def reset_environment(self): + """重置执行环境""" + self.shell.reset() + self._setup_common_imports() + self._setup_chinese_font() + plt.close("all") + self.image_counter = 0 + + def set_variable(self, name: str, value: Any): + """设置执行环境中的变量""" + self.shell.user_ns[name] = value + + def get_environment_info(self) -> str: + """获取当前执行环境的变量信息,用于系统提示词""" + info_parts = [] + + # 获取重要的数据变量 + important_vars = {} + for var_name, var_value in self.shell.user_ns.items(): + if not var_name.startswith("_") and var_name not in [ + "In", + "Out", + "get_ipython", + "exit", + "quit", + ]: + try: + if hasattr(var_value, "shape"): # pandas DataFrame, numpy array + important_vars[var_name] = ( + f"{type(var_value).__name__} with shape {var_value.shape}" + ) + elif var_name in ["session_output_dir"]: # 重要的路径变量 + important_vars[var_name] = str(var_value) + elif ( + isinstance(var_value, (int, float, str, bool)) + and len(str(var_value)) < 100 + ): + important_vars[var_name] = ( + f"{type(var_value).__name__}: {var_value}" + ) + elif hasattr(var_value, "__module__") and var_value.__module__ in [ + "pandas", + "numpy", + "matplotlib.pyplot", + ]: + important_vars[var_name] = f"导入的模块: {var_value.__module__}" + except: + continue + + if important_vars: + info_parts.append("当前环境变量:") + for var_name, var_info in important_vars.items(): + info_parts.append(f"- {var_name}: {var_info}") + else: + info_parts.append("当前环境已预装pandas, numpy, matplotlib等库") + + # 添加输出目录信息 + if "session_output_dir" in self.shell.user_ns: + info_parts.append( + f"图片保存目录: session_output_dir = '{self.shell.user_ns['session_output_dir']}'" + ) + + return "\n".join(info_parts) diff --git a/utils/create_session_dir.py b/utils/create_session_dir.py new file mode 100644 index 0000000..d641aab --- /dev/null +++ b/utils/create_session_dir.py @@ -0,0 +1,15 @@ +import os +from datetime import datetime + + +def create_session_output_dir(base_output_dir, user_input: str) -> str: + """为本次分析创建独立的输出目录""" + + # 使用当前时间创建唯一的会话目录名(格式:YYYYMMDD_HHMMSS) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + session_id = timestamp + dir_name = f"session_{session_id}" + session_dir = os.path.join(base_output_dir, dir_name) + os.makedirs(session_dir, exist_ok=True) + + return session_dir diff --git a/utils/data_loader.py b/utils/data_loader.py new file mode 100644 index 0000000..c1c2cef --- /dev/null +++ b/utils/data_loader.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +import os +import pandas as pd +import io + +def load_and_profile_data(file_paths: list) -> str: + """ + 加载数据并生成数据画像 + + Args: + file_paths: 文件路径列表 + + Returns: + 包含数据画像的Markdown字符串 + """ + profile_summary = "# 数据画像报告 (Data Profile)\n\n" + + if not file_paths: + return profile_summary + "未提供数据文件。" + + for file_path in file_paths: + file_name = os.path.basename(file_path) + profile_summary += f"## 文件: {file_name}\n\n" + + if not os.path.exists(file_path): + profile_summary += f"⚠️ 文件不存在: {file_path}\n\n" + continue + + try: + # 根据扩展名选择加载方式 + ext = os.path.splitext(file_path)[1].lower() + if ext == '.csv': + # 尝试多种编码 + try: + df = pd.read_csv(file_path, encoding='utf-8') + except UnicodeDecodeError: + try: + df = pd.read_csv(file_path, encoding='gbk') + except Exception: + df = pd.read_csv(file_path, encoding='latin1') + elif ext in ['.xlsx', '.xls']: + df = pd.read_excel(file_path) + else: + profile_summary += f"⚠️ 不支持的文件格式: {ext}\n\n" + continue + + # 基础信息 + rows, cols = df.shape + profile_summary += f"- **维度**: {rows} 行 x {cols} 列\n" + profile_summary += f"- **列名**: `{', '.join(df.columns)}`\n\n" + + profile_summary += "### 列详细分布:\n" + + # 遍历分析每列 + for col in df.columns: + dtype = df[col].dtype + null_count = df[col].isnull().sum() + null_ratio = (null_count / rows) * 100 + + profile_summary += f"#### {col} ({dtype})\n" + if null_count > 0: + profile_summary += f"- ⚠️ 空值: {null_count} ({null_ratio:.1f}%)\n" + + # 数值列分析 + if pd.api.types.is_numeric_dtype(dtype): + desc = df[col].describe() + profile_summary += f"- 统计: Min={desc['min']:.2f}, Max={desc['max']:.2f}, Mean={desc['mean']:.2f}\n" + + # 文本/分类列分析 + elif pd.api.types.is_object_dtype(dtype) or pd.api.types.is_categorical_dtype(dtype): + unique_count = df[col].nunique() + profile_summary += f"- 唯一值数量: {unique_count}\n" + + # 如果唯一值较少(<50)或者看起来是分类数据,显示Top分布 + # 这对识别“高频问题”至关重要 + if unique_count > 0: + top_n = df[col].value_counts().head(5) + top_items_str = ", ".join([f"{k}({v})" for k, v in top_n.items()]) + profile_summary += f"- **TOP 5 高频值**: {top_items_str}\n" + + # 时间列分析 + elif pd.api.types.is_datetime64_any_dtype(dtype): + profile_summary += f"- 范围: {df[col].min()} 至 {df[col].max()}\n" + + profile_summary += "\n" + + except Exception as e: + profile_summary += f"❌ 读取或分析文件失败: {str(e)}\n\n" + + return profile_summary diff --git a/utils/extract_code.py b/utils/extract_code.py new file mode 100644 index 0000000..bd2420f --- /dev/null +++ b/utils/extract_code.py @@ -0,0 +1,38 @@ +from typing import Optional +import yaml + + +def extract_code_from_response(response: str) -> Optional[str]: + """从LLM响应中提取代码""" + try: + # 尝试解析YAML + if '```yaml' in response: + start = response.find('```yaml') + 7 + end = response.find('```', start) + yaml_content = response[start:end].strip() + elif '```' in response: + start = response.find('```') + 3 + end = response.find('```', start) + yaml_content = response[start:end].strip() + else: + yaml_content = response.strip() + + yaml_data = yaml.safe_load(yaml_content) + if 'code' in yaml_data: + return yaml_data['code'] + except: + pass + + # 如果YAML解析失败,尝试提取```python代码块 + if '```python' in response: + start = response.find('```python') + 9 + end = response.find('```', start) + if end != -1: + return response[start:end].strip() + elif '```' in response: + start = response.find('```') + 3 + end = response.find('```', start) + if end != -1: + return response[start:end].strip() + + return None \ No newline at end of file diff --git a/utils/fallback_openai_client.py b/utils/fallback_openai_client.py new file mode 100644 index 0000000..2101f22 --- /dev/null +++ b/utils/fallback_openai_client.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- +import asyncio +from typing import Optional, Any, Mapping, Dict +from openai import AsyncOpenAI, APIStatusError, APIConnectionError, APITimeoutError, APIError +from openai.types.chat import ChatCompletion + +class AsyncFallbackOpenAIClient: + """ + 一个支持备用 API 自动切换的异步 OpenAI 客户端。 + 当主 API 调用因特定错误(如内容过滤)失败时,会自动尝试使用备用 API。 + """ + def __init__( + self, + primary_api_key: str, + primary_base_url: str, + primary_model_name: str, + fallback_api_key: Optional[str] = None, + fallback_base_url: Optional[str] = None, + fallback_model_name: Optional[str] = None, + primary_client_args: Optional[Dict[str, Any]] = None, + fallback_client_args: Optional[Dict[str, Any]] = None, + content_filter_error_code: str = "1301", # 特定于 Zhipu 的内容过滤错误代码 + content_filter_error_field: str = "contentFilter", # 特定于 Zhipu 的内容过滤错误字段 + max_retries_primary: int = 1, # 主API重试次数 + max_retries_fallback: int = 1, # 备用API重试次数 + retry_delay_seconds: float = 1.0 # 重试延迟时间 + ): + """ + 初始化 AsyncFallbackOpenAIClient。 + + Args: + primary_api_key: 主 API 的密钥。 + primary_base_url: 主 API 的基础 URL。 + primary_model_name: 主 API 使用的模型名称。 + fallback_api_key: 备用 API 的密钥 (可选)。 + fallback_base_url: 备用 API 的基础 URL (可选)。 + fallback_model_name: 备用 API 使用的模型名称 (可选)。 + primary_client_args: 传递给主 AsyncOpenAI 客户端的其他参数。 + fallback_client_args: 传递给备用 AsyncOpenAI 客户端的其他参数。 + content_filter_error_code: 触发回退的内容过滤错误的特定错误代码。 + content_filter_error_field: 触发回退的内容过滤错误中存在的字段名。 + max_retries_primary: 主 API 失败时的最大重试次数。 + max_retries_fallback: 备用 API 失败时的最大重试次数。 + retry_delay_seconds: 重试前的延迟时间(秒)。 + """ + if not primary_api_key or not primary_base_url: + raise ValueError("主 API 密钥和基础 URL 不能为空。") + + _primary_args = primary_client_args or {} + self.primary_client = AsyncOpenAI(api_key=primary_api_key, base_url=primary_base_url, **_primary_args) + self.primary_model_name = primary_model_name + + self.fallback_client: Optional[AsyncOpenAI] = None + self.fallback_model_name: Optional[str] = None + if fallback_api_key and fallback_base_url and fallback_model_name: + _fallback_args = fallback_client_args or {} + self.fallback_client = AsyncOpenAI(api_key=fallback_api_key, base_url=fallback_base_url, **_fallback_args) + self.fallback_model_name = fallback_model_name + else: + print("⚠️ 警告: 未完全配置备用 API 客户端。如果主 API 失败,将无法进行回退。") + + self.content_filter_error_code = content_filter_error_code + self.content_filter_error_field = content_filter_error_field + self.max_retries_primary = max_retries_primary + self.max_retries_fallback = max_retries_fallback + self.retry_delay_seconds = retry_delay_seconds + self._closed = False + + async def _attempt_api_call( + self, + client: AsyncOpenAI, + model_name: str, + messages: list[Mapping[str, Any]], + max_retries: int, + api_name: str, + **kwargs: Any + ) -> ChatCompletion: + """ + 尝试调用指定的 OpenAI API 客户端,并进行重试。 + """ + last_exception = None + for attempt in range(max_retries + 1): + try: + # print(f"尝试使用 {api_name} API ({client.base_url}) 模型: {kwargs.get('model', model_name)}, 第 {attempt + 1} 次尝试") + completion = await client.chat.completions.create( + model=kwargs.pop('model', model_name), + messages=messages, + **kwargs + ) + return completion + except (APIConnectionError, APITimeoutError) as e: # 通常可以重试的网络错误 + last_exception = e + print(f"⚠️ {api_name} API 调用时发生可重试错误 ({type(e).__name__}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}") + if attempt < max_retries: + await asyncio.sleep(self.retry_delay_seconds * (attempt + 1)) # 增加延迟 + else: + print(f"❌ {api_name} API 在达到最大重试次数后仍然失败。") + except APIStatusError as e: # API 返回的特定状态码错误 + is_content_filter_error = False + if e.status_code == 400: + try: + error_json = e.response.json() + error_details = error_json.get("error", {}) + if (error_details.get("code") == self.content_filter_error_code and + self.content_filter_error_field in error_json): + is_content_filter_error = True + except Exception: + pass # 解析错误响应失败,不认为是内容过滤错误 + + if is_content_filter_error and api_name == "主": # 如果是主 API 的内容过滤错误,则直接抛出以便回退 + raise e + + last_exception = e + print(f"⚠️ {api_name} API 调用时发生 APIStatusError ({e.status_code}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}") + if attempt < max_retries: + await asyncio.sleep(self.retry_delay_seconds * (attempt + 1)) + else: + print(f"❌ {api_name} API 在达到最大重试次数后仍然失败 (APIStatusError)。") + except APIError as e: # 其他不可轻易重试的 OpenAI 错误 + last_exception = e + print(f"❌ {api_name} API 调用时发生不可重试错误 ({type(e).__name__}): {e}") + break # 不再重试此类错误 + + if last_exception: + raise last_exception + raise RuntimeError(f"{api_name} API 调用意外失败。") # 理论上不应到达这里 + + async def chat_completions_create( + self, + messages: list[Mapping[str, Any]], + **kwargs: Any # 用于传递其他 OpenAI 参数,如 max_tokens, temperature 等。 + ) -> ChatCompletion: + """ + 使用主 API 创建聊天补全,如果发生特定内容过滤错误或主 API 调用失败,则回退到备用 API。 + 支持对主 API 和备用 API 的可重试错误进行重试。 + + Args: + messages: OpenAI API 的消息列表。 + **kwargs: 传递给 OpenAI API 调用的其他参数。 + + Returns: + ChatCompletion 对象。 + + Raises: + APIError: 如果主 API 和备用 API (如果尝试) 都返回 API 错误。 + RuntimeError: 如果客户端已关闭。 + """ + if self._closed: + raise RuntimeError("客户端已关闭。") + + try: + completion = await self._attempt_api_call( + client=self.primary_client, + model_name=self.primary_model_name, + messages=messages, + max_retries=self.max_retries_primary, + api_name="主", + **kwargs.copy() + ) + return completion + except APIStatusError as e_primary: + is_content_filter_error = False + if e_primary.status_code == 400: + try: + error_json = e_primary.response.json() + error_details = error_json.get("error", {}) + if (error_details.get("code") == self.content_filter_error_code and + self.content_filter_error_field in error_json): + is_content_filter_error = True + except Exception: + pass + + if is_content_filter_error and self.fallback_client and self.fallback_model_name: + print(f"ℹ️ 主 API 内容过滤错误 ({e_primary.status_code})。尝试切换到备用 API ({self.fallback_client.base_url})...") + try: + fallback_completion = await self._attempt_api_call( + client=self.fallback_client, + model_name=self.fallback_model_name, + messages=messages, + max_retries=self.max_retries_fallback, + api_name="备用", + **kwargs.copy() + ) + print(f"✅ 备用 API 调用成功。") + return fallback_completion + except APIError as e_fallback: + print(f"❌ 备用 API 调用最终失败: {type(e_fallback).__name__} - {e_fallback}") + raise e_fallback + else: + if not (self.fallback_client and self.fallback_model_name and is_content_filter_error): + # 如果不是内容过滤错误,或者没有可用的备用API,则记录主API的原始错误 + print(f"ℹ️ 主 API 错误 ({type(e_primary).__name__}: {e_primary}), 且不满足备用条件或备用API未配置。") + raise e_primary + except APIError as e_primary_other: + print(f"❌ 主 API 调用最终失败 (非内容过滤,错误类型: {type(e_primary_other).__name__}): {e_primary_other}") + if self.fallback_client and self.fallback_model_name: + print(f"ℹ️ 主 API 失败,尝试切换到备用 API ({self.fallback_client.base_url})...") + try: + fallback_completion = await self._attempt_api_call( + client=self.fallback_client, + model_name=self.fallback_model_name, + messages=messages, + max_retries=self.max_retries_fallback, + api_name="备用", + **kwargs.copy() + ) + print(f"✅ 备用 API 调用成功。") + return fallback_completion + except APIError as e_fallback_after_primary_fail: + print(f"❌ 备用 API 在主 API 失败后也调用失败: {type(e_fallback_after_primary_fail).__name__} - {e_fallback_after_primary_fail}") + raise e_fallback_after_primary_fail + else: + raise e_primary_other + + async def close(self): + """异步关闭主客户端和备用客户端 (如果存在)。""" + if not self._closed: + await self.primary_client.close() + if self.fallback_client: + await self.fallback_client.close() + self._closed = True + # print("AsyncFallbackOpenAIClient 已关闭。") + + async def __aenter__(self): + if self._closed: + raise RuntimeError("AsyncFallbackOpenAIClient 不能在关闭后重新进入。请创建一个新实例。") + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() diff --git a/utils/format_execution_result.py b/utils/format_execution_result.py new file mode 100644 index 0000000..7706d92 --- /dev/null +++ b/utils/format_execution_result.py @@ -0,0 +1,25 @@ + +from typing import Any, Dict + + +def format_execution_result(result: Dict[str, Any]) -> str: + """格式化执行结果为用户可读的反馈""" + feedback = [] + + if result['success']: + feedback.append("✅ 代码执行成功") + + if result['output']: + feedback.append(f"📊 输出结果:\n{result['output']}") + + if result.get('variables'): + feedback.append("📋 新生成的变量:") + for var_name, var_info in result['variables'].items(): + feedback.append(f" - {var_name}: {var_info}") + else: + feedback.append("❌ 代码执行失败") + feedback.append(f"错误信息: {result['error']}") + if result['output']: + feedback.append(f"部分输出: {result['output']}") + + return "\n".join(feedback) diff --git a/utils/llm_helper.py b/utils/llm_helper.py new file mode 100644 index 0000000..f24d967 --- /dev/null +++ b/utils/llm_helper.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +""" +LLM调用辅助模块 +""" + +import asyncio +import yaml +from config.llm_config import LLMConfig +from utils.fallback_openai_client import AsyncFallbackOpenAIClient + +class LLMHelper: + """LLM调用辅助类,支持同步和异步调用""" + + def __init__(self, config: LLMConfig = None): + self.config = config + self.client = AsyncFallbackOpenAIClient( + primary_api_key=config.api_key, + primary_base_url=config.base_url, + primary_model_name=config.model + ) + + async def async_call(self, prompt: str, system_prompt: str = None, max_tokens: int = None, temperature: float = None) -> str: + """异步调用LLM""" + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + kwargs = {} + if max_tokens is not None: + kwargs['max_tokens'] = max_tokens + else: + kwargs['max_tokens'] = self.config.max_tokens + + if temperature is not None: + kwargs['temperature'] = temperature + else: + kwargs['temperature'] = self.config.temperature + + try: + response = await self.client.chat_completions_create( + messages=messages, + **kwargs + ) + return response.choices[0].message.content + except Exception as e: + print(f"LLM调用失败: {e}") + return "" + + def call(self, prompt: str, system_prompt: str = None, max_tokens: int = None, temperature: float = None) -> str: + """同步调用LLM""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + import nest_asyncio + nest_asyncio.apply() + + return loop.run_until_complete(self.async_call(prompt, system_prompt, max_tokens, temperature)) + + def parse_yaml_response(self, response: str) -> dict: + """解析YAML格式的响应""" + try: + # 提取```yaml和```之间的内容 + if '```yaml' in response: + start = response.find('```yaml') + 7 + end = response.find('```', start) + yaml_content = response[start:end].strip() + elif '```' in response: + start = response.find('```') + 3 + end = response.find('```', start) + yaml_content = response[start:end].strip() + else: + yaml_content = response.strip() + + return yaml.safe_load(yaml_content) + except Exception as e: + print(f"YAML解析失败: {e}") + print(f"原始响应: {response}") + return {} + + async def close(self): + """关闭客户端""" + await self.client.close() \ No newline at end of file