Files
iov_data_analysis_agent/utils/script_generator.py
2026-01-31 18:00:05 +08:00

216 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
可复用脚本生成器
从分析会话的执行历史中提取成功执行的代码,
合并去重后生成可独立运行的 .py 脚本文件。
"""
import ast
import os
import re
from datetime import datetime
from typing import List, Dict, Any, Set
def extract_imports(code: str) -> Set[str]:
    """
    Collect every import statement in *code* as a set of source strings.

    The code is parsed with :mod:`ast`, so only real ``import`` /
    ``from ... import`` statements (at any nesting depth) are returned:
    parenthesised or backslash-continued imports are returned whole, and
    text inside string literals that merely looks like an import is ignored.
    Each entry is the statement's source text, starting at the ``import`` /
    ``from`` keyword and excluding any trailing comment.

    If *code* cannot be parsed (e.g. it contains notebook magics such as
    ``%matplotlib inline``), fall back to a line heuristic: every stripped
    line starting with ``import `` or ``from `` is returned.

    Args:
        code: Python source of one executed analysis step.

    Returns:
        The set of distinct import statements found.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: some Python versions report NUL bytes this way.
        return {
            line.strip()
            for line in code.split('\n')
            if line.strip().startswith(('import ', 'from '))
        }

    imports = set()
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            segment = ast.get_source_segment(code, node)
            if segment:
                imports.add(segment)
    return imports
def remove_imports(code: str) -> str:
    """
    Remove import statements from *code* so they can be hoisted elsewhere.

    The code is parsed with :mod:`ast`. Every ``import`` / ``from ... import``
    statement (at any nesting depth) that occupies its own line(s) is removed,
    including all continuation lines of a multi-line import. Special cases:

    * An import nested inside a block (function, ``if``, ``try`` ...) is
      replaced by ``pass`` at the same indentation, so the block can never
      become empty.
    * An import that shares a line with other code (``import os; x = 1``) is
      left in place; importing a module twice is harmless.
    * Import-like text inside string literals is untouched.

    If *code* cannot be parsed (e.g. notebook magics), fall back to dropping
    every line whose stripped text starts with ``import `` or ``from ``.

    Args:
        code: Python source of one executed analysis step.

    Returns:
        The source with its import statements removed.
    """
    lines = code.split('\n')
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: some Python versions report NUL bytes this way.
        return '\n'.join(
            line for line in lines
            if not line.strip().startswith(('import ', 'from '))
        )

    dropped = set()   # 0-based indices of lines to delete
    replaced = {}     # 0-based line index -> replacement text
    for node in ast.walk(tree):
        if not isinstance(node, (ast.Import, ast.ImportFrom)):
            continue
        first, last = node.lineno - 1, node.end_lineno - 1
        # AST column offsets are UTF-8 byte offsets, so slice the encoded lines.
        head = lines[first].encode('utf-8')[:node.col_offset]
        tail = lines[last].encode('utf-8')[node.end_col_offset:].strip()
        if head.strip() or (tail and not tail.startswith(b'#')):
            continue  # shares its line with other code; keep it
        if head:
            # Nested import: keep the enclosing block syntactically non-empty.
            replaced[first] = head.decode('utf-8') + 'pass'
            dropped.update(range(first + 1, last + 1))
        else:
            dropped.update(range(first, last + 1))

    return '\n'.join(
        replaced.get(index, line)
        for index, line in enumerate(lines)
        if index not in dropped
    )
def clean_code_block(code: str) -> str:
    """
    Drop statements that duplicate configuration provided by the script template.

    The generated script configures matplotlib fonts once, in its own
    configuration section. Statements touching
    ``plt.rcParams['font.sans-serif']`` or ``plt.rcParams['axes.unicode_minus']``
    (single- or double-quoted keys) are therefore removed from each code block.

    When *code* parses:

    * Only whole simple statements (assignments and expression statements)
      that mention one of these settings and occupy their own line(s) are
      removed, including every line of a multi-line assignment.
    * A removed statement nested in a block is replaced by ``pass`` at the
      same indentation, so the block stays syntactically valid.

    When *code* cannot be parsed (e.g. notebook magics), fall back to dropping
    each individual line that mentions one of the settings. Blank lines
    outside removed statements are preserved.

    Args:
        code: Python source of one executed analysis step.

    Returns:
        The cleaned source.
    """
    # Font configuration is handled once by the script template.
    patterns_to_skip = [
        re.compile(r"plt\.rcParams\[[\"']font\.sans-serif[\"']\]"),
        re.compile(r"plt\.rcParams\[[\"']axes\.unicode_minus[\"']\]"),
    ]

    def mentions_template_config(text: str) -> bool:
        return any(pattern.search(text) for pattern in patterns_to_skip)

    lines = code.split('\n')
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: some Python versions report NUL bytes this way.
        return '\n'.join(line for line in lines if not mentions_template_config(line))

    dropped = set()   # 0-based indices of lines to delete
    replaced = {}     # 0-based line index -> replacement text
    for node in ast.walk(tree):
        if not isinstance(node, (ast.Assign, ast.AugAssign, ast.AnnAssign, ast.Expr)):
            continue
        first, last = node.lineno - 1, node.end_lineno - 1
        # AST column offsets are UTF-8 byte offsets, so slice the encoded lines.
        head = lines[first].encode('utf-8')[:node.col_offset]
        tail = lines[last].encode('utf-8')[node.end_col_offset:].strip()
        if head.strip() or (tail and not tail.startswith(b'#')):
            continue  # shares its line with other code; keep it
        segment = ast.get_source_segment(code, node)
        if not segment or not mentions_template_config(segment):
            continue
        if head:
            # Nested statement: keep the enclosing block syntactically non-empty.
            replaced[first] = head.decode('utf-8') + 'pass'
            dropped.update(range(first + 1, last + 1))
        else:
            dropped.update(range(first, last + 1))

    return '\n'.join(
        replaced.get(index, line)
        for index, line in enumerate(lines)
        if index not in dropped
    )
def _escape_docstring_text(text: str) -> str:
    """
    Make *text* safe to embed inside a regular (non-raw) triple-quoted docstring.

    Backslashes are doubled so Windows paths (e.g. a backslash followed by
    "U") are not parsed as escape sequences by the generated script. Runs of
    three double quotes are escaped so they cannot close the docstring early.
    """
    return text.replace('\\', '\\\\').replace('"""', '\\"\\"\\"')


def _collect_code_blocks(analysis_results: List[Dict[str, Any]]) -> tuple:
    """
    Return ``(imports, blocks)`` harvested from successfully executed steps.

    *imports* is the set of import statements found in those steps.
    *blocks* is a list of ``{"round": ..., "code": ...}`` dicts, one per step,
    holding that step's code with imports and template-managed configuration
    removed. Steps whose cleaned code is empty are omitted.
    """
    all_imports: Set[str] = set()
    code_blocks: List[Dict[str, Any]] = []
    for result in analysis_results or []:
        # Figure-collection steps carry no analysis code to replay.
        if result.get("action") == "collect_figures":
            continue
        code = result.get("code") or ""
        # Tolerate a missing / None / non-dict "result" instead of crashing.
        exec_result = result.get("result") or {}
        if (not code
                or not isinstance(exec_result, dict)
                or not exec_result.get("success", False)):
            continue
        all_imports.update(extract_imports(code))
        cleaned_code = clean_code_block(remove_imports(code)).strip()
        if cleaned_code:
            code_blocks.append({"round": result.get("round", 0), "code": cleaned_code})
    return all_imports, code_blocks


def _build_script_header(data_files: List[str], user_requirement: str,
                         script_filename: str, now: datetime) -> str:
    """Return the shebang, module docstring and base import of the generated script."""
    requirement = (user_requirement[:200] + '...'
                   if len(user_requirement) > 200 else user_requirement)
    # Escape after truncation so a cut can never leave a dangling backslash.
    files_text = _escape_docstring_text(', '.join(data_files))
    requirement_text = _escape_docstring_text(requirement)
    generated_at = now.strftime("%Y-%m-%d %H:%M:%S")
    return f'''#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据分析脚本 - 自动生成
=====================================
原始数据文件: {files_text}
生成时间: {generated_at}
原始需求: {requirement_text}
=====================================
使用方法:
1. 修改下方 DATA_FILES 列表中的文件路径
2. 修改 OUTPUT_DIR 指定输出目录
3. 运行: python {script_filename}
"""

import os
'''


def _build_config_section(data_files: List[str]) -> str:
    """Return the editable configuration section (paths, output dir, CJK font setup)."""
    return f'''

# ========== 配置区域 (可修改) ==========
# 数据文件路径 - 修改此处以分析不同的数据
DATA_FILES = {repr(data_files)}

# 输出目录 - 图片和报告将保存在此目录
OUTPUT_DIR = "./analysis_output"

# 创建输出目录
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ========== 字体配置 (中文显示) ==========
import platform
import matplotlib.pyplot as plt

system_name = platform.system()
if system_name == 'Darwin':
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'PingFang SC', 'sans-serif']
elif system_name == 'Windows':
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'sans-serif']
else:
    plt.rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'sans-serif']
plt.rcParams['axes.unicode_minus'] = False

# 设置 session_output_dir 变量(兼容原始代码)
session_output_dir = OUTPUT_DIR
'''


def generate_reusable_script(
    analysis_results: List[Dict[str, Any]],
    data_files: List[str],
    session_output_dir: str,
    user_requirement: str = ""
) -> str:
    """
    Generate a standalone, reusable Python script from an analysis session.

    Every successfully executed code step in *analysis_results* is collected.
    Its imports are hoisted into the script (deduplicated; plain ``import``
    statements first, then ``from`` imports, each group sorted), and the
    remaining code is appended in the order of *analysis_results*, one
    commented section per round. A configuration section (``DATA_FILES``,
    ``OUTPUT_DIR``, matplotlib CJK font setup, ``session_output_dir`` alias)
    precedes the analysis code.

    Args:
        analysis_results: Records produced during the analysis. Each may hold
            'action', 'code', 'round' and 'result'; 'result' is a dict whose
            'success' flag marks a successful run. Records with action
            'collect_figures' are skipped.
        data_files: Paths of the original data files.
        session_output_dir: Directory the script file is written into.
        user_requirement: The user's original request (truncated to 200
            characters in the script's docstring).

    Returns:
        Path of the written script, or "" when there was no successful code
        to export or the file could not be written.
    """
    all_imports, code_blocks = _collect_code_blocks(analysis_results)
    if not code_blocks:
        print("[WARN] 没有成功执行的代码块,跳过脚本生成")
        return ""

    # Normalise inputs: accept None and path-like entries.
    data_files = [str(path) for path in (data_files or [])]
    user_requirement = user_requirement or ""

    now = datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    script_filename = f"分析脚本_{timestamp}.py"

    # Hoisted imports: plain "import" statements first, then "from" imports, each sorted.
    standard_imports = sorted(imp for imp in all_imports if not imp.startswith('from'))
    from_imports = sorted(imp for imp in all_imports if imp.startswith('from'))
    imports_section = '\n'.join(standard_imports + from_imports)

    code_section = "\n# ========== 分析代码 ==========\n\n" + "".join(
        f"# --- 第 {block['round']} 轮分析 ---\n{block['code']}\n\n"
        for block in code_blocks
    )

    # Plain (non-f) string: the braces are evaluated by the generated script.
    script_footer = '''
# ========== 完成 ==========
print("\\n" + "=" * 50)
print("[OK] 分析完成!")
print(f"[OUTPUT] 输出目录: {os.path.abspath(OUTPUT_DIR)}")
print("=" * 50)
'''

    full_script = (
        _build_script_header(data_files, user_requirement, script_filename, now)
        + imports_section
        + _build_config_section(data_files)
        + code_section
        + script_footer
    )

    script_path = os.path.join(session_output_dir, script_filename)
    try:
        with open(script_path, 'w', encoding='utf-8') as f:
            f.write(full_script)
    except (OSError, UnicodeError) as e:
        print(f"[ERROR] 保存脚本失败: {e}")
        return ""
    print(f"[OK] 可复用脚本已生成: {script_path}")
    return script_path