Files
iov_data_analysis_agent/utils/code_executor.py
2026-01-09 16:52:45 +08:00

502 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
安全的代码执行器,基于 IPython 提供 notebook 环境下的代码执行功能
"""
import os
import sys
import ast
import traceback
import io
from typing import Dict, Any, List, Optional, Tuple
from contextlib import redirect_stdout, redirect_stderr
from IPython.core.interactiveshell import InteractiveShell
from IPython.utils.capture import capture_output
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
class CodeExecutor:
"""
安全的代码执行器,限制依赖库,捕获输出,支持图片保存与路径输出
"""
ALLOWED_IMPORTS = {
"pandas",
"pd",
"numpy",
"glob",
"np",
"subprocess",
"matplotlib",
"matplotlib.pyplot",
"plt",
"seaborn",
"sns",
"duckdb",
"scipy",
"sklearn",
"sklearn.feature_extraction.text",
"sklearn.preprocessing",
"sklearn.model_selection",
"sklearn.metrics",
"sklearn.ensemble",
"sklearn.linear_model",
"sklearn.cluster",
"sklearn.decomposition",
"sklearn.manifold",
"statsmodels",
"plotly",
"dash",
"requests",
"urllib",
"os",
"sys",
"json",
"csv",
"datetime",
"time",
"math",
"statistics",
"re",
"pathlib",
"io",
"collections",
"itertools",
"functools",
"operator",
"warnings",
"logging",
"copy",
"pickle",
"gzip",
"zipfile",
"yaml",
"typing",
"dataclasses",
"enum",
"sqlite3",
"jieba",
"wordcloud",
"PIL",
"random",
"networkx",
}
def __init__(self, output_dir: str = "outputs"):
"""
初始化代码执行器
Args:
output_dir: 输出目录,用于保存图片和文件
"""
self.output_dir = os.path.abspath(output_dir)
os.makedirs(self.output_dir, exist_ok=True)
# 初始化 IPython shell
self.shell = InteractiveShell.instance()
# 设置中文字体
self._setup_chinese_font()
# 预导入常用库
self._setup_common_imports()
# 图片计数器
self.image_counter = 0
def _setup_chinese_font(self):
"""设置matplotlib中文字体显示"""
try:
# 设置matplotlib使用Agg backend避免GUI问题
matplotlib.use("Agg")
# 获取系统可用字体
available_fonts = [f.name for f in fm.fontManager.ttflist]
# 设置matplotlib使用系统可用中文字体
# macOS系统常用中文字体按优先级排序
chinese_fonts = [
"Hiragino Sans GB", # macOS中文简体
"Songti SC", # macOS宋体简体
"PingFang SC", # macOS苹方简体
"Heiti SC", # macOS黑体简体
"Heiti TC", # macOS黑体繁体
"PingFang HK", # macOS苹方香港
"SimHei", # Windows黑体
"STHeiti", # 华文黑体
"WenQuanYi Micro Hei", # Linux文泉驿微米黑
"DejaVu Sans", # 默认无衬线字体
"Arial Unicode MS", # Arial Unicode
]
# 检查系统中实际存在的字体
system_chinese_fonts = [
font for font in chinese_fonts if font in available_fonts
]
# 如果没有找到合适的中文字体,尝试更宽松的搜索
if not system_chinese_fonts:
print("警告:未找到精确匹配的中文字体,尝试更宽松的搜索...")
# 更宽松的字体匹配(包含部分名称)
fallback_fonts = []
for available_font in available_fonts:
if any(
keyword in available_font
for keyword in [
"Hei",
"Song",
"Fang",
"Kai",
"Hiragino",
"PingFang",
"ST",
]
):
fallback_fonts.append(available_font)
if fallback_fonts:
system_chinese_fonts = fallback_fonts[:3] # 取前3个匹配的字体
print(f"找到备选中文字体: {system_chinese_fonts}")
else:
print("警告:系统中未找到合适的中文字体,使用系统默认字体")
system_chinese_fonts = ["DejaVu Sans", "Arial Unicode MS"]
# 设置字体配置
plt.rcParams["font.sans-serif"] = system_chinese_fonts + [
"DejaVu Sans",
"Arial Unicode MS",
]
plt.rcParams["axes.unicode_minus"] = False
plt.rcParams["font.family"] = "sans-serif"
# 在shell中也设置相同的字体配置
font_list_str = str(
system_chinese_fonts + ["DejaVu Sans", "Arial Unicode MS"]
)
self.shell.run_cell(
f"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
# 设置中文字体
plt.rcParams['font.sans-serif'] = {font_list_str}
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.family'] = 'sans-serif'
# 确保matplotlib缓存目录可写
import os
cache_dir = os.path.expanduser('~/.matplotlib')
if not os.path.exists(cache_dir):
os.makedirs(cache_dir, exist_ok=True)
os.environ['MPLCONFIGDIR'] = cache_dir
"""
)
except Exception as e:
print(f"设置中文字体失败: {e}")
# 即使失败也要设置基本的matplotlib配置
try:
matplotlib.use("Agg")
plt.rcParams["axes.unicode_minus"] = False
except:
pass
def _setup_common_imports(self):
"""预导入常用库"""
common_imports = """
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import duckdb
import os
import json
from IPython.display import display
"""
try:
self.shell.run_cell(common_imports)
# 确保display函数在shell的用户命名空间中可用
from IPython.display import display
self.shell.user_ns["display"] = display
except Exception as e:
print(f"预导入库失败: {e}")
def _check_code_safety(self, code: str) -> Tuple[bool, str]:
"""
检查代码安全性,限制导入的库
Returns:
(is_safe, error_message)
"""
try:
tree = ast.parse(code)
except SyntaxError as e:
return False, f"语法错误: {e}"
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
# 获取根包名 (e.g. sklearn.preprocessing -> sklearn)
root_package = alias.name.split('.')[0]
if root_package not in self.ALLOWED_IMPORTS and alias.name not in self.ALLOWED_IMPORTS:
return False, f"不允许的导入: {alias.name}"
elif isinstance(node, ast.ImportFrom):
if node.module:
root_package = node.module.split('.')[0]
if root_package not in self.ALLOWED_IMPORTS and node.module not in self.ALLOWED_IMPORTS:
return False, f"不允许的导入: {node.module}"
# 检查属性访问防止通过os.system等方式绕过
elif isinstance(node, ast.Attribute):
# 检查是否访问了os模块的属性
if isinstance(node.value, ast.Name) and node.value.id == "os":
# 允许的os子模块和函数白名单
allowed_os_attributes = {
"path", "environ", "getcwd", "listdir", "makedirs", "mkdir", "remove", "rmdir",
"path.join", "path.exists", "path.abspath", "path.dirname",
"path.basename", "path.splitext", "path.isdir", "path.isfile",
"sep", "name", "linesep", "stat", "getpid"
}
# 检查直接属性访问 (如 os.getcwd)
if node.attr not in allowed_os_attributes:
# 进一步检查如果是 os.path.xxx 这种形式
# Note: ast.Attribute 嵌套结构比较复杂,简单处理只允许 os.path 和上述白名单
if node.attr == "path":
pass # 允许访问 os.path
else:
return False, f"不允许的os属性访问: os.{node.attr}"
# 检查危险函数调用
elif isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
if node.func.id in ["exec", "eval", "open", "__import__"]:
return False, f"不允许的函数调用: {node.func.id}"
return True, ""
def get_current_figures_info(self) -> List[Dict[str, Any]]:
"""获取当前matplotlib图形信息但不自动保存"""
figures_info = []
# 获取当前所有图形
fig_nums = plt.get_fignums()
for fig_num in fig_nums:
fig = plt.figure(fig_num)
if fig.get_axes(): # 只处理有内容的图形
figures_info.append(
{
"figure_number": fig_num,
"axes_count": len(fig.get_axes()),
"figure_size": fig.get_size_inches().tolist(),
"has_content": True,
}
)
return figures_info
def _format_table_output(self, obj: Any) -> str:
"""格式化表格输出,限制行数"""
if hasattr(obj, "shape") and hasattr(obj, "head"): # pandas DataFrame
rows, cols = obj.shape
print(f"\n数据表形状: {rows}行 x {cols}")
print(f"列名: {list(obj.columns)}")
if rows <= 15:
return str(obj)
else:
head_part = obj.head(5)
tail_part = obj.tail(5)
return f"{head_part}\n...\n(省略 {rows-10} 行)\n...\n{tail_part}"
return str(obj)
def execute_code(self, code: str) -> Dict[str, Any]:
"""
执行代码并返回结果
Args:
code: 要执行的Python代码
Returns:
{
'success': bool,
'output': str,
'error': str,
'variables': Dict[str, Any] # 新生成的重要变量
}
"""
# 检查代码安全性
is_safe, safety_error = self._check_code_safety(code)
if not is_safe:
return {
"success": False,
"output": "",
"error": f"代码安全检查失败: {safety_error}",
"variables": {},
}
# 记录执行前的变量
vars_before = set(self.shell.user_ns.keys())
try:
# 使用IPython的capture_output来捕获所有输出
with capture_output() as captured:
result = self.shell.run_cell(code)
# 检查执行结果
if result.error_before_exec:
error_msg = str(result.error_before_exec)
return {
"success": False,
"output": captured.stdout,
"error": f"执行前错误: {error_msg}",
"variables": {},
}
if result.error_in_exec:
error_msg = str(result.error_in_exec)
return {
"success": False,
"output": captured.stdout,
"error": f"执行错误: {error_msg}",
"variables": {},
}
# 获取输出
output = captured.stdout
# 如果有返回值,添加到输出
if result.result is not None:
formatted_result = self._format_table_output(result.result)
output += f"\n{formatted_result}"
# 记录新产生的重要变量(简化版本)
vars_after = set(self.shell.user_ns.keys())
new_vars = vars_after - vars_before
# 只记录新创建的DataFrame等重要数据结构
important_new_vars = {}
for var_name in new_vars:
if not var_name.startswith("_"):
try:
var_value = self.shell.user_ns[var_name]
if hasattr(var_value, "shape"): # pandas DataFrame, numpy array
important_new_vars[var_name] = (
f"{type(var_value).__name__} with shape {var_value.shape}"
)
elif var_name in ["session_output_dir"]: # 重要的配置变量
important_new_vars[var_name] = str(var_value)
except:
pass
# --- 自动保存机制 start ---
# 检查是否有未关闭的图片,如果有,自动保存
try:
open_fig_nums = plt.get_fignums()
if open_fig_nums:
for fig_num in open_fig_nums:
fig = plt.figure(fig_num)
# 生成自动保存的文件名
auto_filename = f"autosave_fig_{self.image_counter}_{fig_num}.png"
auto_filepath = os.path.join(self.output_dir, auto_filename)
try:
# 尝试保存
fig.savefig(auto_filepath, bbox_inches='tight')
print(f"💾 [Auto-Save] 检测到未闭合图表,已安全保存至: {auto_filepath}")
# 添加到输出中告知Agent
output += f"\n[Auto-Save] ⚠️ 检测到Figure {fig_num}未关闭,系统已自动保存为: {auto_filename}"
self.image_counter += 1
except Exception as e:
print(f"⚠️ [Auto-Save] 保存失败: {e}")
finally:
plt.close(fig_num)
except Exception as e:
print(f"⚠️ [Auto-Save Global] 异常: {e}")
# --- 自动保存机制 end ---
return {
"success": True,
"output": output,
"error": "",
"variables": important_new_vars,
}
except Exception as e:
return {
"success": False,
"output": captured.stdout if "captured" in locals() else "",
"error": f"执行异常: {str(e)}\n{traceback.format_exc()}",
"variables": {},
}
def reset_environment(self):
"""重置执行环境"""
self.shell.reset()
self._setup_common_imports()
self._setup_chinese_font()
plt.close("all")
self.image_counter = 0
def set_variable(self, name: str, value: Any):
"""设置执行环境中的变量"""
self.shell.user_ns[name] = value
def get_environment_info(self) -> str:
"""获取当前执行环境的变量信息,用于系统提示词"""
info_parts = []
# 获取重要的数据变量
important_vars = {}
for var_name, var_value in self.shell.user_ns.items():
if not var_name.startswith("_") and var_name not in [
"In",
"Out",
"get_ipython",
"exit",
"quit",
]:
try:
if hasattr(var_value, "shape"): # pandas DataFrame, numpy array
important_vars[var_name] = (
f"{type(var_value).__name__} with shape {var_value.shape}"
)
elif var_name in ["session_output_dir"]: # 重要的路径变量
important_vars[var_name] = str(var_value)
elif (
isinstance(var_value, (int, float, str, bool))
and len(str(var_value)) < 100
):
important_vars[var_name] = (
f"{type(var_value).__name__}: {var_value}"
)
elif hasattr(var_value, "__module__") and var_value.__module__ in [
"pandas",
"numpy",
"matplotlib.pyplot",
]:
important_vars[var_name] = f"导入的模块: {var_value.__module__}"
except:
continue
if important_vars:
info_parts.append("当前环境变量:")
for var_name, var_info in important_vars.items():
info_parts.append(f"- {var_name}: {var_info}")
else:
info_parts.append("当前环境已预装pandas, numpy, matplotlib等库")
# 添加输出目录信息
if "session_output_dir" in self.shell.user_ns:
info_parts.append(
f"图片保存目录: session_output_dir = '{self.shell.user_ns['session_output_dir']}'"
)
return "\n".join(info_parts)