530 lines
21 KiB
Python
530 lines
21 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
简化的 Notebook 数据分析智能体
|
||
仅包含用户和助手两个角色
|
||
2. 图片必须保存到指定的会话目录中,输出绝对路径,禁止使用plt.show()
|
||
3. 表格输出控制:超过15行只显示前5行和后5行
|
||
4. 强制使用SimHei字体:plt.rcParams['font.sans-serif'] = ['SimHei']
|
||
5. 输出格式严格使用YAML共享上下文的单轮对话模式
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import yaml
|
||
from typing import Dict, Any, List, Optional
|
||
from utils.create_session_dir import create_session_output_dir
|
||
from utils.format_execution_result import format_execution_result
|
||
from utils.extract_code import extract_code_from_response
|
||
from utils.data_loader import load_and_profile_data
|
||
from utils.llm_helper import LLMHelper
|
||
from utils.code_executor import CodeExecutor
|
||
from config.llm_config import LLMConfig
|
||
from prompts import data_analysis_system_prompt, final_report_system_prompt
|
||
|
||
|
||
class DataAnalysisAgent:
    """Data analysis agent driving an LLM <-> notebook-executor loop.

    Responsibilities:
    - accept the user's natural-language analysis request
    - ask the LLM to generate Python analysis code
    - execute that code and collect the results
    - feed execution results back so the LLM can generate follow-up code
    - finally ask the LLM for a report and save it to the session directory
    """

    def __init__(
        self,
        llm_config: Optional["LLMConfig"] = None,
        output_dir: str = "outputs",
        max_rounds: int = 20,
        force_max_rounds: bool = False,
    ):
        """Initialize the agent.

        Args:
            llm_config: LLM configuration; a default ``LLMConfig()`` is created when omitted.
            output_dir: Base directory under which per-session output directories are created.
            max_rounds: Maximum number of LLM conversation rounds per ``analyze`` call.
            force_max_rounds: If True, keep running until ``max_rounds`` even when the
                LLM signals ``analysis_complete``.
        """
        self.config = llm_config or LLMConfig()
        self.llm = LLMHelper(self.config)
        self.base_output_dir = output_dir
        self.max_rounds = max_rounds
        self.force_max_rounds = force_max_rounds

        # Conversation history and per-run context (reset by analyze() / reset()).
        self.conversation_history = []
        self.analysis_results = []
        self.current_round = 0
        self.session_output_dir = None
        self.executor = None  # created per session in analyze()
        self.data_profile = ""  # data profile text, reused for the final report

    def _process_response(self, response: str) -> Dict[str, Any]:
        """Parse an LLM response, detect its action and dispatch to the matching handler.

        Only YAML parsing is guarded: if the response cannot be parsed into a mapping
        it is handled as ``generate_code`` with code scraped from the raw text. Errors
        raised by the handlers themselves (e.g. during code execution) propagate to the
        caller instead of being mistaken for parse failures, so code never runs twice.

        Args:
            response: Raw LLM response text.

        Returns:
            The handler's result dict (contains ``action``, ``response`` and ``continue``).
        """
        try:
            yaml_data = self.llm.parse_yaml_response(response)
            if not isinstance(yaml_data, dict):
                raise ValueError(f"YAML顶层结构不是字典: {type(yaml_data).__name__}")
        except Exception as e:
            print(f"⚠️ 解析响应失败: {str(e)},尝试提取代码并按generate_code处理")
            # Without a YAML "code" field, _handle_generate_code falls back to
            # extract_code_from_response on the raw text.
            return self._handle_generate_code(response, {})

        action = yaml_data.get("action", "generate_code")
        print(f"🎯 检测到动作: {action}")

        handlers = {
            "analysis_complete": self._handle_analysis_complete,
            "collect_figures": self._handle_collect_figures,
            "generate_code": self._handle_generate_code,
        }
        # Guard against unhashable/non-string action values from malformed YAML.
        handler = handlers.get(action) if isinstance(action, str) else None
        if handler is None:
            print(f"⚠️ 未知动作类型: {action},按generate_code处理")
            handler = self._handle_generate_code
        return handler(response, yaml_data)

    def _handle_analysis_complete(
        self, response: str, yaml_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Handle the ``analysis_complete`` action: stop the loop and pass the report on."""
        print("✅ 分析任务完成")
        # ``or`` also covers an explicit null/empty final_report in the YAML.
        final_report = yaml_data.get("final_report") or "分析完成,无最终报告"
        return {
            "action": "analysis_complete",
            "final_report": final_report,
            "response": response,
            "continue": False,
        }

    def _resolve_figure_path(self, file_path: str) -> str:
        """Return an existing path for ``file_path`` (also trying the session dir), or ""."""
        if not file_path:
            return ""
        if os.path.isfile(file_path):
            return file_path
        # The LLM sometimes reports a path relative to the session directory.
        if self.session_output_dir and not os.path.isabs(file_path):
            candidate = os.path.join(self.session_output_dir, file_path)
            if os.path.isfile(candidate):
                return candidate
        return ""

    def _handle_collect_figures(
        self, response: str, yaml_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Handle the ``collect_figures`` action.

        Each figure is recorded exactly once, and only if its file actually exists, so
        the final report never references broken images. Figures whose file cannot be
        found are reported in ``missing_figures`` so the LLM can be told about them.

        Returns:
            Dict with ``collected_figures`` (existing figures), ``missing_figures``
            (paths or filenames that could not be found) and ``continue=True``.
        """
        print("📊 开始收集图片")
        figures_to_collect = yaml_data.get("figures_to_collect") or []
        if isinstance(figures_to_collect, dict):
            # Tolerate a single figure given as a mapping instead of a list.
            figures_to_collect = [figures_to_collect]

        collected_figures = []
        missing_figures = []

        for figure_info in figures_to_collect:
            if not isinstance(figure_info, dict):
                print(f"⚠️ 跳过无效的图片条目: {figure_info!r}")
                continue

            figure_number = figure_info.get("figure_number")
            # An explicit null must not leak into the file name as "None".
            if figure_number is None or figure_number == "":
                figure_number = "未知"
            if figure_number != "未知":
                default_filename = f"figure_{figure_number}.png"
            else:
                default_filename = "figure_unknown.png"
            filename = figure_info.get("filename") or default_filename
            file_path = str(figure_info.get("file_path") or "")
            description = figure_info.get("description", "")
            analysis = figure_info.get("analysis", "")

            print(f"📈 收集图片 {figure_number}: {filename}")
            print(f" 📂 路径: {file_path}")
            print(f" 📝 描述: {description}")
            print(f" 🔍 分析: {analysis}")

            # Only record figures whose file really exists, to avoid broken images.
            resolved_path = self._resolve_figure_path(file_path)
            if resolved_path:
                print(f" ✅ 文件存在: {resolved_path}")
                collected_figures.append(
                    {
                        "figure_number": figure_number,
                        "filename": filename,
                        "file_path": resolved_path,
                        "description": description,
                        "analysis": analysis,
                    }
                )
            elif file_path:
                print(f" ⚠️ 文件不存在: {file_path}")
                missing_figures.append(file_path)
            else:
                print(" ⚠️ 未提供文件路径")
                missing_figures.append(filename)

        return {
            "action": "collect_figures",
            "collected_figures": collected_figures,
            "missing_figures": missing_figures,
            "response": response,
            "continue": True,
        }

    @staticmethod
    def _strip_code_fences(code: str) -> str:
        """Strip surrounding whitespace and an enclosing Markdown code fence, if any.

        Handles fences with any language tag (```python, ```python3, ```), single-line
        fences (```code```) and prose trailing after the closing fence.
        """
        code = code.strip()
        if not code.startswith("```"):
            return code
        first_newline = code.find("\n")
        # Drop the opening fence line (including its language tag).
        body = code[3:] if first_newline == -1 else code[first_newline + 1:]
        closing = body.rfind("\n```")
        if closing != -1:
            body = body[:closing]
        else:
            body = body.rstrip()
            if body.endswith("```"):
                body = body[:-3]
        return body.strip()

    def _handle_generate_code(
        self, response: str, yaml_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Handle code generation: extract the code, execute it and format the feedback.

        Returns:
            A ``generate_code`` result (code, raw result, formatted feedback), or an
            ``invalid_response`` result when no executable code could be found.
        """
        # Prefer the YAML "code" field; it is more precise than scraping raw text.
        code = yaml_data.get("code") or ""
        if not isinstance(code, str):
            code = str(code)
        if not code.strip():
            code = extract_code_from_response(response) or ""

        # Second-pass cleanup: YAML-parsed code sometimes still carries Markdown fences.
        code = self._strip_code_fences(code)

        if not code:
            # No code: the response format is wrong and the LLM must regenerate.
            print("⚠️ 未从响应中提取到可执行代码,要求LLM重新生成")
            return {
                "action": "invalid_response",
                "error": "响应中缺少可执行代码",
                "response": response,
                "continue": True,
            }

        print(f"🔧 执行代码:\n{code}")
        print("-" * 40)

        result = self.executor.execute_code(code)
        feedback = format_execution_result(result)
        print(f"📋 执行反馈:\n{feedback}")

        return {
            "action": "generate_code",
            "code": code,
            "result": result,
            "feedback": feedback,
            "response": response,
            "continue": True,
        }

    def analyze(
        self,
        user_input: str,
        files: Optional[List[str]] = None,
        session_output_dir: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run the full analysis loop and return the final report bundle.

        Args:
            user_input: The user's natural-language analysis request.
            files: Paths of the data files to analyze (optional).
            session_output_dir: Explicit session output directory; when omitted a new
                one is created under ``self.base_output_dir``.

        Returns:
            The result dict produced by ``_generate_final_report``.
        """
        # Reset per-run state.
        self.conversation_history = []
        self.analysis_results = []
        self.current_round = 0

        # Dedicated output directory for this run.
        if session_output_dir:
            self.session_output_dir = session_output_dir
        else:
            self.session_output_dir = create_session_output_dir(
                self.base_output_dir, user_input
            )

        # Fresh executor bound to the session directory; executed code can read the
        # directory through the ``session_output_dir`` variable.
        self.executor = CodeExecutor(self.session_output_dir)
        self.executor.set_variable("session_output_dir", self.session_output_dir)

        # Profile the data files up front so the LLM can plan with real statistics.
        data_profile = ""
        if files:
            print("🔍 正在生成数据画像...")
            data_profile = load_and_profile_data(files)
            print("✅ 数据画像生成完毕")
        # Kept on the instance so the final report prompt can include it.
        self.data_profile = data_profile

        # Initial user prompt.
        initial_prompt = f"""用户需求: {user_input}"""
        if files:
            initial_prompt += f"\n数据文件: {', '.join(files)}"
        if data_profile:
            initial_prompt += f"\n\n{data_profile}\n\n请根据上述【数据画像】中的统计信息(如高频值、缺失率、数据范围)来制定分析策略。如果发现明显的高频问题或异常分布,请优先进行深度分析。"

        print("🚀 开始数据分析任务")
        print(f"📝 用户需求: {user_input}")
        if files:
            print(f"📁 数据文件: {', '.join(files)}")
        print(f"📂 输出目录: {self.session_output_dir}")
        print(f"🔢 最大轮数: {self.max_rounds}")
        if self.force_max_rounds:
            print(f"⚡ 强制模式: 将运行满 {self.max_rounds} 轮(忽略AI完成信号)")
        print("=" * 60)

        self.conversation_history.append({"role": "user", "content": initial_prompt})

        completed = False
        while self.current_round < self.max_rounds:
            self.current_round += 1
            print(f"\n🔄 第 {self.current_round} 轮分析")
            try:
                # Refresh the system prompt with the executor's current variables each round.
                notebook_variables = self.executor.get_environment_info()
                formatted_system_prompt = data_analysis_system_prompt.format(
                    notebook_variables=notebook_variables
                )

                response = self.llm.call(
                    prompt=self._build_conversation_prompt(),
                    system_prompt=formatted_system_prompt,
                )
                print(f"🤖 助手响应:\n{response}")

                process_result = self._process_response(response)

                # Record the assistant turn before deciding whether to stop, so the
                # returned history also contains the final message.
                self.conversation_history.append(
                    {"role": "assistant", "content": response}
                )

                # Stop on the LLM's completion signal unless forced to run all rounds.
                if not self.force_max_rounds and not process_result.get("continue", True):
                    print("\n✅ 分析完成!")
                    completed = True
                    break

                self._record_round_feedback(process_result, response)

            except Exception as e:
                error_msg = f"LLM调用错误: {str(e)}"
                print(f"❌ {error_msg}")
                self.conversation_history.append(
                    {
                        "role": "user",
                        "content": f"发生错误: {error_msg},请重新生成代码。",
                    }
                )

        if not completed and self.current_round >= self.max_rounds:
            print(f"\n⚠️ 已达到最大轮数 ({self.max_rounds}),分析结束")

        return self._generate_final_report()

    def _record_round_feedback(self, process_result: Dict[str, Any], response: str) -> None:
        """Append the user-side feedback for one round and record its analysis result."""
        action = process_result.get("action")

        if action == "generate_code":
            feedback = process_result.get("feedback", "")
            self.conversation_history.append(
                {"role": "user", "content": f"代码执行反馈:\n{feedback}"}
            )
            self.analysis_results.append(
                {
                    "round": self.current_round,
                    "action": "generate_code",
                    "code": process_result.get("code", ""),
                    "result": process_result.get("result", {}),
                    "response": response,
                }
            )

        elif action == "collect_figures":
            collected_figures = process_result.get("collected_figures", [])
            missing_figures = process_result.get("missing_figures", [])

            feedback = f"已收集 {len(collected_figures)} 个有效图片及其分析。"
            if missing_figures:
                feedback += f"\n⚠️ 以下图片未找到,请检查代码是否成功保存了这些图片: {missing_figures}"

            self.conversation_history.append(
                {
                    "role": "user",
                    "content": f"图片收集反馈:\n{feedback}\n请继续下一步分析。",
                }
            )
            self.analysis_results.append(
                {
                    "round": self.current_round,
                    "action": "collect_figures",
                    "collected_figures": collected_figures,
                    "missing_figures": missing_figures,
                    "response": response,
                }
            )

        elif action == "invalid_response":
            # Without this message the next prompt would end with the LLM's own
            # unusable reply and nothing would ask it to regenerate.
            error = process_result.get("error", "响应格式无效")
            self.conversation_history.append(
                {
                    "role": "user",
                    "content": f"响应处理失败: {error}。请按要求的格式重新生成包含可执行代码的响应。",
                }
            )

        elif action == "analysis_complete":
            # Only reached in force_max_rounds mode: ask the LLM to keep analysing.
            self.conversation_history.append(
                {
                    "role": "user",
                    "content": f"当前要求完成全部 {self.max_rounds} 轮分析,请不要结束,继续进行更深入的分析。",
                }
            )

    def _build_conversation_prompt(self) -> str:
        """Render the conversation history as "用户: ..." / "助手: ..." turns."""
        return "\n\n".join(
            f"{'用户' if message['role'] == 'user' else '助手'}: {message['content']}"
            for message in self.conversation_history
        )

    def _extract_final_report(self, response: Any) -> str:
        """Pull the Markdown report out of the final LLM response (always returns a str)."""
        if not response:
            return "LLM未返回有效报告内容"
        response = str(response)

        try:
            yaml_data = self.llm.parse_yaml_response(response)
        except Exception as e:
            # Parsing failed entirely: use the raw response as the report.
            print(f"⚠️ YAML解析失败 ({e}),直接使用原始响应作为报告")
            return response

        if not isinstance(yaml_data, dict):
            print("⚠️ YAML解析结果不是字典,直接使用原始响应作为报告")
            return response

        # Case 1: standard YAML with action: analysis_complete.
        if yaml_data.get("action") == "analysis_complete":
            return str(yaml_data.get("final_report") or response)

        # Case 2: a final_report field without the action marker.
        report = yaml_data.get("final_report")
        if isinstance(report, str) and report.strip():
            return report

        # Case 3: looks like a Markdown report (has headings) -> accept as-is.
        if "# " in response:
            print("⚠️ 未检测到标准YAML动作,但内容疑似Markdown报告,直接采纳")
            return response
        return "LLM未返回有效报告内容"

    def _generate_final_report(self) -> Dict[str, Any]:
        """Generate the final report, save it to the session directory and return all results."""
        # Gather collected figures across rounds; figures collected more than once are
        # de-duplicated by path (latest info wins, first-seen order is kept).
        figures_by_key: Dict[str, Dict[str, Any]] = {}
        for result in self.analysis_results:
            if result.get("action") == "collect_figures":
                for figure in result.get("collected_figures", []):
                    key = figure.get("file_path") or figure.get("filename") or str(len(figures_by_key))
                    figures_by_key[key] = figure
        all_figures = list(figures_by_key.values())

        print("\n📊 开始生成最终分析报告...")
        print(f"📂 输出目录: {self.session_output_dir}")
        print(f"🔢 总轮数: {self.current_round}")
        print(f"📈 收集图片: {len(all_figures)} 个")

        try:
            final_report_prompt = self._build_final_report_prompt(all_figures)
            response = self.llm.call(
                prompt=final_report_prompt,
                system_prompt="你将会接收到一个数据分析任务的最终报告请求,请根据提供的分析结果和图片信息生成完整的分析报告。",
                max_tokens=16384,  # large limit so the full report fits
            )
            final_report_content = self._extract_final_report(response)
            print("✅ 最终报告生成完成")
        except Exception as e:
            print(f"❌ 生成最终报告时出错: {str(e)}")
            final_report_content = f"报告生成失败: {str(e)}"

        # Save the report next to the figures so relative image links resolve.
        report_file_path = os.path.join(self.session_output_dir, "最终分析报告.md")
        try:
            os.makedirs(self.session_output_dir, exist_ok=True)
            with open(report_file_path, "w", encoding="utf-8") as f:
                f.write(final_report_content)
            print(f"📄 最终报告已保存至: {report_file_path}")
        except OSError as e:
            print(f"❌ 保存报告文件失败: {str(e)}")

        return {
            "session_output_dir": self.session_output_dir,
            "total_rounds": self.current_round,
            "analysis_results": self.analysis_results,
            "collected_figures": all_figures,
            "conversation_history": self.conversation_history,
            "final_report": final_report_content,
            "report_file_path": report_file_path,
        }

    def _figure_relative_path(self, figure: Dict[str, Any]) -> str:
        """Return a "./..." path for ``figure`` relative to the session (report) directory.

        Uses the figure's real ``file_path`` when it lies inside the session directory;
        otherwise falls back to ``./<filename>``.
        """
        file_path = figure.get("file_path") or ""
        if file_path and self.session_output_dir:
            try:
                relative = os.path.relpath(
                    os.path.abspath(file_path), os.path.abspath(self.session_output_dir)
                )
            except ValueError:  # e.g. different drives on Windows
                relative = ""
            if relative and relative != os.pardir and not relative.startswith(os.pardir + os.sep):
                return "./" + relative.replace(os.sep, "/")
        filename = figure.get("filename") or "未知文件名"
        return f"./{filename}"

    def _build_final_report_prompt(self, all_figures: List[Dict[str, Any]]) -> str:
        """Build the prompt used to ask the LLM for the final report."""
        # Figure summary, using paths relative to the report file.
        if all_figures:
            figure_parts = ["\n生成的图片及分析:\n"]
            for i, figure in enumerate(all_figures, 1):
                filename = figure.get("filename", "未知文件名")
                relative_path = self._figure_relative_path(figure)
                figure_parts.append(f"{i}. {filename}\n")
                figure_parts.append(f" 相对路径: {relative_path}\n")
                figure_parts.append(f" 描述: {figure.get('description', '无描述')}\n")
                figure_parts.append(f" 分析: {figure.get('analysis', '无分析')}\n\n")
            figures_summary = "".join(figure_parts)
        else:
            figures_summary = "\n本次分析未生成图片。\n"

        # Summary of successfully executed code blocks only.
        code_parts = []
        success_code_count = 0
        for result in self.analysis_results:
            if result.get("action") == "collect_figures" or not result.get("code"):
                continue
            exec_result = result.get("result") or {}
            if not isinstance(exec_result, dict) or not exec_result.get("success"):
                continue
            success_code_count += 1
            code_parts.append(f"代码块 {success_code_count}: 执行成功\n")
            output = exec_result.get("output")
            if output:
                code_parts.append(f"输出: {output}\n\n")
        code_results_summary = "".join(code_parts)

        # Shared template from prompts.py, with the data profile injected.
        prompt = final_report_system_prompt.format(
            current_round=self.current_round,
            session_output_dir=self.session_output_dir,
            data_profile=self.data_profile,
            figures_summary=figures_summary,
            code_results_summary=code_results_summary,
        )

        # Explicitly require relative image paths (report and images share a directory).
        prompt += """

📁 **图片路径使用说明**:
报告和图片都在同一目录下,请在报告中使用相对路径引用图片:
- 格式: ![图片描述](./图片文件名.png)
- 示例: ![销售额月度趋势](./figure_1.png)
- 这样可以确保报告在不同环境下都能正确显示图片
"""
        return prompt

    def reset(self):
        """Reset the agent's per-run state (safe to call before analyze())."""
        self.conversation_history = []
        self.analysis_results = []
        self.current_round = 0
        self.data_profile = ""
        # The executor only exists once analyze() has created a session.
        if self.executor is not None:
            self.executor.reset_environment()
|