216 lines
6.5 KiB
Python
216 lines
6.5 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
可复用脚本生成器
|
|||
|
|
|
|||
|
|
从分析会话的执行历史中提取成功执行的代码,
|
|||
|
|
合并去重后生成可独立运行的 .py 脚本文件。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import re
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import List, Dict, Any, Set
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_imports(code: str) -> Set[str]:
    """Collect every import statement found in ``code``.

    A statement is any line whose stripped text starts with ``import `` or
    ``from ``, at any indentation. Single-line statements are returned as
    the stripped line. Statements that span several physical lines
    (parenthesised ``from ... import (...)`` or backslash continuations) are
    collapsed into one equivalent single-line statement, with inline
    comments dropped, so the result can be re-emitted at module level.

    Args:
        code: Source text to scan.

    Returns:
        The set of distinct import statements, one string per statement.
    """
    return {
        _render_import(group)
        for is_import, group in _split_import_statements(code)
        if is_import
    }


def remove_imports(code: str) -> str:
    """Return ``code`` with every import statement removed.

    Uses the same statement detection as ``extract_imports``, so all
    physical lines of a multi-line import are removed together instead of
    leaving the continuation lines behind.

    Args:
        code: Source text to filter.

    Returns:
        The remaining lines joined with a newline, original order preserved.
    """
    kept = [
        line
        for is_import, group in _split_import_statements(code)
        if not is_import
        for line in group
    ]
    return '\n'.join(kept)


def _split_import_statements(code: str):
    """Yield ``(is_import, physical_lines)`` groups covering all of ``code``.

    Non-import lines are yielded one at a time. An import statement is
    yielded as a single group holding every physical line it spans.
    """
    lines = code.split('\n')
    position = 0
    while position < len(lines):
        first = lines[position]
        position += 1
        if not first.strip().startswith(('import ', 'from ')):
            yield False, [first]
            continue

        group = [first]
        # Import statements cannot contain string literals, so '#' always
        # starts a comment here and can be cut before counting brackets.
        body = first.split('#', 1)[0].strip()
        depth = 0
        # Only follow an open bracket when the line really looks like
        # "from pkg import (", so stray text starting with "from " cannot
        # swallow the lines that follow it.
        if re.match(r'from\s+[\w.]+\s+import\s*\(', body):
            depth = body.count('(') - body.count(')')
        continued = body.endswith('\\')

        while (depth > 0 or continued) and position < len(lines):
            following = lines[position]
            position += 1
            group.append(following)
            body = following.split('#', 1)[0].strip()
            depth += body.count('(') - body.count(')')
            continued = body.endswith('\\')

        yield True, group


def _render_import(group: List[str]) -> str:
    """Turn the physical lines of one import statement into a single line."""
    if len(group) == 1:
        # Single-line statements are returned as the stripped line.
        return group[0].strip()

    pieces = []
    for raw in group:
        text = raw.split('#', 1)[0].strip()
        if text.endswith('\\'):
            # Drop the trailing line-continuation backslash.
            text = text[:-1].rstrip()
        if text:
            pieces.append(text)

    statement = ' '.join(pieces)
    # Tidy "( a, b, )" into "(a, b)".
    statement = re.sub(r'\(\s+', '(', statement)
    return re.sub(r',?\s*\)', ')', statement)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def clean_code_block(code: str) -> str:
    """Strip matplotlib font configuration lines from a code block.

    Lines that touch ``plt.rcParams['font.sans-serif']`` or
    ``plt.rcParams['axes.unicode_minus']`` (single- or double-quoted key)
    are removed. A matching line at column 0 is dropped. A matching
    indented line is replaced by ``pass`` at the same indentation, so a
    block whose only statement was the font setting (for example each
    branch of an ``if/else``) stays syntactically valid.

    All other lines, including blank ones, are kept unchanged.

    Args:
        code: Source text of one executed code block.

    Returns:
        The cleaned source text, lines joined with a newline.
    """
    font_config = re.compile(
        r"plt\.rcParams\[['\"](?:font\.sans-serif|axes\.unicode_minus)['\"]\]"
    )

    kept = []
    for line in code.split('\n'):
        if not font_config.search(line):
            kept.append(line)
            continue

        indent = line[:len(line) - len(line.lstrip())]
        if indent:
            # Keep the enclosing block non-empty; an empty suite would be
            # a syntax error in the generated script.
            kept.append(indent + 'pass')

    return '\n'.join(kept)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_reusable_script(
    analysis_results: List[Dict[str, Any]],
    data_files: List[str],
    session_output_dir: str,
    user_requirement: str = ""
) -> str:
    """Build a standalone ``.py`` script from a session's successful code.

    Every successfully executed code block in ``analysis_results`` is
    stripped of its imports and font configuration. The imports are merged,
    de-duplicated and sorted, and everything is assembled into one script
    consisting of a header docstring, imports, a configuration section, the
    code blocks and a footer. The script is written into
    ``session_output_dir`` as ``分析脚本_<timestamp>.py``.

    Args:
        analysis_results: Recorded results; each entry is read for the keys
            ``action``, ``code``, ``result`` (a dict with ``success``) and
            ``round``. Entries with action ``collect_figures`` are skipped.
        data_files: Paths of the original data files; embedded in the
            header and as the generated ``DATA_FILES`` list.
        session_output_dir: Directory the script is written to. It is
            created if it does not exist.
        user_requirement: The user's original request; the first 200
            characters are quoted in the header.

    Returns:
        The path of the written script, or ``""`` when there is no
        successful code to export or the file could not be written.
    """
    all_imports, code_blocks = _collect_successful_blocks(analysis_results)

    if not code_blocks:
        print("[WARN] 没有成功执行的代码块,跳过脚本生成")
        return ""

    now = datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    generated_at = now.strftime("%Y-%m-%d %H:%M:%S")
    script_filename = f"分析脚本_{timestamp}.py"

    user_requirement = user_requirement or ""
    if len(user_requirement) > 200:
        requirement_summary = user_requirement[:200] + '...'
    else:
        requirement_summary = user_requirement

    # Free text goes inside the generated docstring, so it must be escaped:
    # a Windows path or a stray triple quote would otherwise make the
    # generated file unparsable.
    data_files_text = _docstring_safe(', '.join(data_files))
    requirement_text = _docstring_safe(requirement_summary)

    script_header = f'''#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据分析脚本 - 自动生成
=====================================
原始数据文件: {data_files_text}
生成时间: {generated_at}
原始需求: {requirement_text}
=====================================

使用方法:
1. 修改下方 DATA_FILES 列表中的文件路径
2. 修改 OUTPUT_DIR 指定输出目录
3. 运行: python {script_filename}
"""

'''

    # __future__ imports must precede every other statement in a module,
    # so they are emitted before the header's own "import os".
    future_imports = sorted(
        imp for imp in all_imports if imp.startswith('from __future__ ')
    )
    if future_imports:
        script_header += '\n'.join(future_imports) + '\n'
    script_header += 'import os\n'

    # Plain "import x" statements first, then "from x import y", each sorted.
    standard_imports = sorted(
        imp for imp in all_imports if imp.startswith('import ')
    )
    from_imports = sorted(
        imp for imp in all_imports
        if imp.startswith('from ') and not imp.startswith('from __future__ ')
    )
    imports_section = '\n'.join(standard_imports + from_imports)

    # User-editable configuration plus the shared font setup.
    config_section = f'''
# ========== 配置区域 (可修改) ==========

# 数据文件路径 - 修改此处以分析不同的数据
DATA_FILES = {repr(data_files)}

# 输出目录 - 图片和报告将保存在此目录
OUTPUT_DIR = "./analysis_output"

# 创建输出目录
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ========== 字体配置 (中文显示) ==========
import platform
import matplotlib.pyplot as plt

system_name = platform.system()
if system_name == 'Darwin':
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'PingFang SC', 'sans-serif']
elif system_name == 'Windows':
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'sans-serif']
else:
    plt.rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'sans-serif']
plt.rcParams['axes.unicode_minus'] = False

# 设置 session_output_dir 变量(兼容原始代码)
session_output_dir = OUTPUT_DIR
'''

    # One labelled section per exported code block, in recorded order.
    code_section = "\n# ========== 分析代码 ==========\n\n" + ''.join(
        f"# --- 第 {block['round']} 轮分析 ---\n{block['code']}\n\n"
        for block in code_blocks
    )

    script_footer = '''
# ========== 完成 ==========
print("\\n" + "=" * 50)
print("[OK] 分析完成!")
print(f"[OUTPUT] 输出目录: {os.path.abspath(OUTPUT_DIR)}")
print("=" * 50)
'''

    full_script = (
        script_header
        + imports_section
        + config_section
        + code_section
        + script_footer
    )

    script_path = os.path.join(session_output_dir, script_filename)
    try:
        if session_output_dir:
            os.makedirs(session_output_dir, exist_ok=True)
        with open(script_path, 'w', encoding='utf-8') as f:
            f.write(full_script)
    except (OSError, UnicodeError) as e:
        # Best effort: report the failure and signal it with an empty path.
        print(f"[ERROR] 保存脚本失败: {e}")
        return ""

    print(f"[OK] 可复用脚本已生成: {script_path}")
    return script_path


def _collect_successful_blocks(analysis_results):
    """Return ``(imports, blocks)`` gathered from successfully run entries.

    ``imports`` is the set of import statements seen across all successful
    entries. ``blocks`` is a list of ``{"round", "code"}`` dicts holding the
    cleaned, non-empty code of each successful entry, in input order.
    """
    all_imports: Set[str] = set()
    code_blocks: List[Dict[str, Any]] = []

    for result in analysis_results or []:
        if not isinstance(result, dict):
            continue
        if result.get("action") == "collect_figures":
            continue

        code = result.get("code", "")
        exec_result = result.get("result")
        # A missing or None execution result counts as "not successful".
        if not code or not isinstance(exec_result, dict):
            continue
        if not exec_result.get("success", False):
            continue

        all_imports.update(extract_imports(code))

        cleaned_code = clean_code_block(remove_imports(code)).strip()
        if cleaned_code:
            code_blocks.append({
                "round": result.get("round", 0),
                "code": cleaned_code,
            })

    return all_imports, code_blocks


def _docstring_safe(text: str) -> str:
    """Escape ``text`` for embedding in a triple-double-quoted docstring."""
    # Backslashes first, so the escapes added for quotes are not doubled.
    return text.replace('\\', '\\\\').replace('"', '\\"')
|