Add web session analysis platform with follow-up topics

This commit is contained in:
2026-03-09 22:23:00 +08:00
commit 17ce711e49
30 changed files with 10681 additions and 0 deletions

16
utils/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""
Package initializer for the utils module: re-exports the public helpers.
"""
from utils.code_executor import CodeExecutor
from utils.execution_session_client import ExecutionSessionClient
from utils.llm_helper import LLMHelper
from utils.fallback_openai_client import AsyncFallbackOpenAIClient

# Explicit public API of the package.
__all__ = [
    "CodeExecutor",
    "ExecutionSessionClient",
    "LLMHelper",
    "AsyncFallbackOpenAIClient",
]

456
utils/code_executor.py Normal file
View File

@@ -0,0 +1,456 @@
# -*- coding: utf-8 -*-
"""
安全的代码执行器,基于 IPython 提供 notebook 环境下的代码执行功能
"""
import os
import sys
import ast
import traceback
import io
from typing import Dict, Any, List, Optional, Tuple
from contextlib import redirect_stdout, redirect_stderr
from IPython.core.interactiveshell import InteractiveShell
from IPython.utils.capture import capture_output
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
class CodeExecutor:
    """
    Sandboxed code executor backed by a private IPython ``InteractiveShell``.

    Restricts which libraries user code may import (AST allowlist), captures
    stdout produced by executed cells, and supports saving matplotlib figures
    under ``output_dir``.
    """

    # Modules (and the aliases users commonly bind them to, e.g. "pd", "np")
    # that user code is allowed to import; everything else is rejected by the
    # static check in _check_code_safety.
    ALLOWED_IMPORTS = {
        "pandas",
        "pd",
        "numpy",
        "np",
        "matplotlib",
        "matplotlib.pyplot",
        "plt",
        "seaborn",
        "sns",
        "duckdb",
        "scipy",
        "sklearn",
        "statsmodels",
        "plotly",
        "dash",
        "requests",
        "urllib",
        "os",
        "sys",
        "json",
        "csv",
        "datetime",
        "time",
        "math",
        "statistics",
        "re",
        "pathlib",
        "io",
        "collections",
        "itertools",
        "functools",
        "operator",
        "warnings",
        "logging",
        "copy",
        "pickle",
        "gzip",
        "zipfile",
        "yaml",
        "typing",
        "dataclasses",
        "enum",
        "sqlite3",
        "jieba",
        "wordcloud",
        "PIL",
        "random",
        "networkx",
    }

    def __init__(self, output_dir: str = "outputs"):
        """
        Initialize the executor.

        Args:
            output_dir: Directory used to save images and other output files.
        """
        self.output_dir = os.path.abspath(output_dir)
        os.makedirs(self.output_dir, exist_ok=True)
        # Each executor owns its own shell so state is never shared across
        # analysis tasks.
        self.shell = InteractiveShell()
        # Prime the isolated environment (common imports + CJK fonts).
        self.reset_environment()
        # Counter for generated figure files.
        self.image_counter = 0

    def _setup_chinese_font(self):
        """Configure matplotlib so Chinese (CJK) labels render correctly."""
        try:
            # Use the Agg backend to avoid GUI problems in headless runs.
            matplotlib.use("Agg")
            # Font names actually installed on this system.
            available_fonts = [f.name for f in fm.fontManager.ttflist]
            # Candidate CJK fonts in priority order (mostly macOS names, with
            # Windows/Linux fallbacks at the end).
            chinese_fonts = [
                "Hiragino Sans GB",  # macOS Simplified Chinese
                "Songti SC",  # macOS Songti (Simplified)
                "PingFang SC",  # macOS PingFang (Simplified)
                "Heiti SC",  # macOS Heiti (Simplified)
                "Heiti TC",  # macOS Heiti (Traditional)
                "PingFang HK",  # macOS PingFang (Hong Kong)
                "SimHei",  # Windows SimHei
                "STHeiti",  # STHeiti
                "WenQuanYi Micro Hei",  # Linux WenQuanYi Micro Hei
                "DejaVu Sans",  # default sans-serif fallback
                "Arial Unicode MS",  # Arial Unicode
            ]
            # Keep only the candidates that actually exist on this system.
            system_chinese_fonts = [
                font for font in chinese_fonts if font in available_fonts
            ]
            # No exact match: retry with a looser substring search.
            if not system_chinese_fonts:
                print("警告:未找到精确匹配的中文字体,尝试更宽松的搜索...")
                # Looser matching on partial font names.
                fallback_fonts = []
                for available_font in available_fonts:
                    if any(
                        keyword in available_font
                        for keyword in [
                            "Hei",
                            "Song",
                            "Fang",
                            "Kai",
                            "Hiragino",
                            "PingFang",
                            "ST",
                        ]
                    ):
                        fallback_fonts.append(available_font)
                if fallback_fonts:
                    system_chinese_fonts = fallback_fonts[:3]  # keep first 3 matches
                    print(f"找到备选中文字体: {system_chinese_fonts}")
                else:
                    print("警告:系统中未找到合适的中文字体,使用系统默认字体")
                    system_chinese_fonts = ["DejaVu Sans", "Arial Unicode MS"]
            # Apply the font configuration in this process.
            plt.rcParams["font.sans-serif"] = system_chinese_fonts + [
                "DejaVu Sans",
                "Arial Unicode MS",
            ]
            plt.rcParams["axes.unicode_minus"] = False
            plt.rcParams["font.family"] = "sans-serif"
            # Mirror the same font configuration inside the IPython shell, so
            # user code that re-imports matplotlib sees identical settings.
            font_list_str = str(
                system_chinese_fonts + ["DejaVu Sans", "Arial Unicode MS"]
            )
            self.shell.run_cell(
                f"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
# 设置中文字体
plt.rcParams['font.sans-serif'] = {font_list_str}
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.family'] = 'sans-serif'
# 确保matplotlib缓存目录可写
import os
cache_dir = os.path.expanduser('~/.matplotlib')
if not os.path.exists(cache_dir):
    os.makedirs(cache_dir, exist_ok=True)
os.environ['MPLCONFIGDIR'] = cache_dir
"""
            )
        except Exception as e:
            print(f"设置中文字体失败: {e}")
            # Even on failure, apply the minimal safe matplotlib settings.
            try:
                matplotlib.use("Agg")
                plt.rcParams["axes.unicode_minus"] = False
            except:
                pass

    def _setup_common_imports(self):
        """Pre-import commonly used libraries into the shell namespace."""
        common_imports = """
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import duckdb
import os
import json
from IPython.display import display
"""
        try:
            self.shell.run_cell(common_imports)
            # Make sure display() is available in the shell's user namespace
            # even if the run_cell import was shadowed.
            from IPython.display import display
            self.shell.user_ns["display"] = display
        except Exception as e:
            print(f"预导入库失败: {e}")

    def _check_code_safety(self, code: str) -> Tuple[bool, str]:
        """
        Statically check code safety: restrict imports, os attribute access,
        and dangerous builtin calls.

        Returns:
            (is_safe, error_message)
        """
        try:
            tree = ast.parse(code)
        except SyntaxError as e:
            return False, f"语法错误: {e}"
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    if alias.name not in self.ALLOWED_IMPORTS:
                        return False, f"不允许的导入: {alias.name}"
            elif isinstance(node, ast.ImportFrom):
                if node.module not in self.ALLOWED_IMPORTS:
                    return False, f"不允许的导入: {node.module}"
            # Check attribute access to prevent bypasses such as os.system.
            elif isinstance(node, ast.Attribute):
                # Is this an attribute on the os module?
                if isinstance(node.value, ast.Name) and node.value.id == "os":
                    # Allowlist of os sub-modules/functions.
                    allowed_os_attributes = {
                        "path", "environ", "getcwd", "listdir", "makedirs", "mkdir", "remove", "rmdir",
                        "path.join", "path.exists", "path.abspath", "path.dirname",
                        "path.basename", "path.splitext", "path.isdir", "path.isfile",
                        "sep", "name", "linesep", "stat", "getpid"
                    }
                    # Direct attribute access (e.g. os.getcwd).
                    if node.attr not in allowed_os_attributes:
                        # Nested os.path.xxx forms: ast.Attribute nesting is
                        # complex, so simply allow os.path plus the allowlist.
                        if node.attr == "path":
                            pass  # os.path access is allowed
                        else:
                            return False, f"不允许的os属性访问: os.{node.attr}"
            # Reject dangerous builtin calls.
            elif isinstance(node, ast.Call):
                if isinstance(node.func, ast.Name):
                    if node.func.id in ["exec", "eval", "__import__"]:
                        return False, f"不允许的函数调用: {node.func.id}"
        return True, ""

    def get_current_figures_info(self) -> List[Dict[str, Any]]:
        """Describe currently open matplotlib figures without saving them."""
        figures_info = []
        # Enumerate all open figure numbers.
        fig_nums = plt.get_fignums()
        for fig_num in fig_nums:
            fig = plt.figure(fig_num)
            if fig.get_axes():  # only report figures that have content
                figures_info.append(
                    {
                        "figure_number": fig_num,
                        "axes_count": len(fig.get_axes()),
                        "figure_size": fig.get_size_inches().tolist(),
                        "has_content": True,
                    }
                )
        return figures_info

    def _format_table_output(self, obj: Any) -> str:
        """Format a tabular object for display, truncating long DataFrames.

        NOTE(review): the shape/columns summary is print()ed here rather than
        returned, so it goes to the caller's stdout, not the captured output.
        """
        if hasattr(obj, "shape") and hasattr(obj, "head"):  # pandas DataFrame
            rows, cols = obj.shape
            print(f"\n数据表形状: {rows}行 x {cols}")
            print(f"列名: {list(obj.columns)}")
            if rows <= 15:
                return str(obj)
            else:
                # Show 5 head + 5 tail rows and elide the middle.
                head_part = obj.head(5)
                tail_part = obj.tail(5)
                return f"{head_part}\n...\n(省略 {rows-10} 行)\n...\n{tail_part}"
        return str(obj)

    def execute_code(self, code: str) -> Dict[str, Any]:
        """
        Execute code in the shell and return the outcome.

        Args:
            code: Python source to execute.

        Returns:
            {
                'success': bool,
                'output': str,
                'error': str,
                'variables': Dict[str, Any]  # newly created notable variables
            }
        """
        # Static safety check first; unsafe code is never executed.
        is_safe, safety_error = self._check_code_safety(code)
        if not is_safe:
            return {
                "success": False,
                "output": "",
                "error": f"代码安全检查失败: {safety_error}",
                "variables": {},
            }
        # Snapshot the namespace so new variables can be detected afterwards.
        vars_before = set(self.shell.user_ns.keys())
        try:
            # Capture everything the cell writes via IPython's capture_output.
            with capture_output() as captured:
                result = self.shell.run_cell(code)
            # Errors raised before the cell body ran (e.g. transform errors).
            if result.error_before_exec:
                error_msg = str(result.error_before_exec)
                return {
                    "success": False,
                    "output": captured.stdout,
                    "error": f"执行前错误: {error_msg}",
                    "variables": {},
                }
            # Errors raised while the cell body was executing.
            if result.error_in_exec:
                error_msg = str(result.error_in_exec)
                return {
                    "success": False,
                    "output": captured.stdout,
                    "error": f"执行错误: {error_msg}",
                    "variables": {},
                }
            # Collect captured stdout.
            output = captured.stdout
            # Append the cell's expression result, formatted, if any.
            if result.result is not None:
                formatted_result = self._format_table_output(result.result)
                output += f"\n{formatted_result}"
            # Detect notable variables created by this cell (simplified).
            vars_after = set(self.shell.user_ns.keys())
            new_vars = vars_after - vars_before
            # Only record important data structures such as new DataFrames.
            important_new_vars = {}
            for var_name in new_vars:
                if not var_name.startswith("_"):
                    try:
                        var_value = self.shell.user_ns[var_name]
                        if hasattr(var_value, "shape"):  # pandas DataFrame, numpy array
                            important_new_vars[var_name] = (
                                f"{type(var_value).__name__} with shape {var_value.shape}"
                            )
                        elif var_name in ["session_output_dir"]:  # important config variable
                            important_new_vars[var_name] = str(var_value)
                    except:
                        pass
            return {
                "success": True,
                "output": output,
                "error": "",
                "variables": important_new_vars,
            }
        except Exception as e:
            return {
                "success": False,
                # capture_output may have failed before binding `captured`.
                "output": captured.stdout if "captured" in locals() else "",
                "error": f"执行异常: {str(e)}\n{traceback.format_exc()}",
                "variables": {},
            }

    def reset_environment(self):
        """Reset the shell namespace and re-apply imports/fonts; closes figures."""
        self.shell.reset()
        self._setup_common_imports()
        self._setup_chinese_font()
        plt.close("all")
        self.image_counter = 0

    def set_variable(self, name: str, value: Any):
        """Inject a variable into the execution namespace."""
        self.shell.user_ns[name] = value

    def get_environment_info(self) -> str:
        """Summarize notable namespace variables for use in a system prompt."""
        info_parts = []
        # Collect the important data variables, skipping IPython internals.
        important_vars = {}
        for var_name, var_value in self.shell.user_ns.items():
            if not var_name.startswith("_") and var_name not in [
                "In",
                "Out",
                "get_ipython",
                "exit",
                "quit",
            ]:
                try:
                    if hasattr(var_value, "shape"):  # pandas DataFrame, numpy array
                        important_vars[var_name] = (
                            f"{type(var_value).__name__} with shape {var_value.shape}"
                        )
                    elif var_name in ["session_output_dir"]:  # important path variable
                        important_vars[var_name] = str(var_value)
                    elif (
                        isinstance(var_value, (int, float, str, bool))
                        and len(str(var_value)) < 100
                    ):
                        important_vars[var_name] = (
                            f"{type(var_value).__name__}: {var_value}"
                        )
                    elif hasattr(var_value, "__module__") and var_value.__module__ in [
                        "pandas",
                        "numpy",
                        "matplotlib.pyplot",
                    ]:
                        important_vars[var_name] = f"导入的模块: {var_value.__module__}"
                except:
                    continue
        if important_vars:
            info_parts.append("当前环境变量:")
            for var_name, var_info in important_vars.items():
                info_parts.append(f"- {var_name}: {var_info}")
        else:
            info_parts.append("当前环境已预装pandas, numpy, matplotlib等库")
        # Tell the model where figures should be written.
        if "session_output_dir" in self.shell.user_ns:
            info_parts.append(
                f"图片保存目录: session_output_dir = '{self.shell.user_ns['session_output_dir']}'"
            )
        return "\n".join(info_parts)

View File

@@ -0,0 +1,15 @@
import os
from datetime import datetime
def create_session_output_dir(base_output_dir, user_input: str) -> str:
    """Create a unique output directory for one analysis session.

    Args:
        base_output_dir: Parent directory that holds all session directories.
        user_input: The user's question. Kept for interface compatibility;
            it is not currently encoded into the directory name.

    Returns:
        Path of the newly created ``session_YYYYMMDD_HHMMSS`` directory
        (suffixed ``_1``, ``_2``, ... when needed for uniqueness).
    """
    # Session id is the start timestamp, format: YYYYMMDD_HHMMSS.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    session_dir = os.path.join(base_output_dir, f"session_{timestamp}")
    # Two sessions started within the same second previously shared one
    # directory silently (exist_ok=True); disambiguate with a numeric suffix.
    candidate = session_dir
    suffix = 1
    while os.path.exists(candidate):
        candidate = f"{session_dir}_{suffix}"
        suffix += 1
    os.makedirs(candidate, exist_ok=True)
    return candidate

90
utils/data_loader.py Normal file
View File

@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
import os
import pandas as pd
import io
def load_and_profile_data(file_paths: list) -> str:
    """
    Load data files and generate a data-profile report.

    Args:
        file_paths: List of file paths (.csv / .xlsx / .xls supported).

    Returns:
        A Markdown string containing the data profile. Per-file read or
        analysis failures are reported inline instead of raising.
    """
    profile_summary = "# 数据画像报告 (Data Profile)\n\n"
    if not file_paths:
        return profile_summary + "未提供数据文件。"
    for file_path in file_paths:
        file_name = os.path.basename(file_path)
        profile_summary += f"## 文件: {file_name}\n\n"
        if not os.path.exists(file_path):
            profile_summary += f"⚠️ 文件不存在: {file_path}\n\n"
            continue
        try:
            # Choose the loader by file extension.
            ext = os.path.splitext(file_path)[1].lower()
            if ext == '.csv':
                # Try several encodings (utf-8, then gbk, then latin1).
                try:
                    df = pd.read_csv(file_path, encoding='utf-8')
                except UnicodeDecodeError:
                    try:
                        df = pd.read_csv(file_path, encoding='gbk')
                    except Exception:
                        df = pd.read_csv(file_path, encoding='latin1')
            elif ext in ['.xlsx', '.xls']:
                df = pd.read_excel(file_path)
            else:
                profile_summary += f"⚠️ 不支持的文件格式: {ext}\n\n"
                continue
            # Basic shape information.
            rows, cols = df.shape
            profile_summary += f"- **维度**: {rows} 行 x {cols}\n"
            profile_summary += f"- **列名**: `{', '.join(df.columns)}`\n\n"
            profile_summary += "### 列详细分布:\n"
            # Analyze every column.
            for col in df.columns:
                dtype = df[col].dtype
                null_count = df[col].isnull().sum()
                # Guard against header-only/empty files: rows == 0 previously
                # raised ZeroDivisionError and aborted the whole profile.
                null_ratio = (null_count / rows) * 100 if rows else 0.0
                profile_summary += f"#### {col} ({dtype})\n"
                if null_count > 0:
                    profile_summary += f"- ⚠️ 空值: {null_count} ({null_ratio:.1f}%)\n"
                # Numeric columns: basic statistics.
                if pd.api.types.is_numeric_dtype(dtype):
                    desc = df[col].describe()
                    profile_summary += f"- 统计: Min={desc['min']:.2f}, Max={desc['max']:.2f}, Mean={desc['mean']:.2f}\n"
                # Text/categorical columns (is_categorical_dtype is deprecated;
                # the isinstance check is the documented replacement).
                elif pd.api.types.is_object_dtype(dtype) or isinstance(dtype, pd.CategoricalDtype):
                    unique_count = df[col].nunique()
                    profile_summary += f"- 唯一值数量: {unique_count}\n"
                    # Show the top-value distribution; this is essential for
                    # spotting "high-frequency issues" in the data.
                    if unique_count > 0:
                        top_n = df[col].value_counts().head(5)
                        top_items_str = ", ".join([f"{k}({v})" for k, v in top_n.items()])
                        profile_summary += f"- **TOP 5 高频值**: {top_items_str}\n"
                # Datetime columns: value range.
                elif pd.api.types.is_datetime64_any_dtype(dtype):
                    profile_summary += f"- 范围: {df[col].min()}{df[col].max()}\n"
            profile_summary += "\n"
        except Exception as e:
            profile_summary += f"❌ 读取或分析文件失败: {str(e)}\n\n"
    return profile_summary

View File

@@ -0,0 +1,219 @@
# -*- coding: utf-8 -*-
"""
Client for a per-analysis execution worker subprocess.
"""
import json
import os
import queue
import subprocess
import sys
import threading
import uuid
from typing import Any, Dict, Optional
class WorkerSessionError(RuntimeError):
    """Raised when the execution worker cannot serve a request: dead or
    unstarted process, broken pipe, invalid/mismatched response, or a
    worker-side error status."""
class WorkerTimeoutError(WorkerSessionError):
    """Raised when the worker does not respond within the configured timeout;
    the worker process is force-terminated before this is raised."""
class ExecutionSessionClient:
    """Client that proxies CodeExecutor methods to a dedicated worker process.

    Communication is a JSON-lines request/response protocol over the worker's
    stdin/stdout; the worker's stderr is appended to ``execution_worker.log``
    inside the session output directory.
    """

    def __init__(
        self,
        output_dir: str,
        allowed_files=None,
        python_executable: str = None,
        request_timeout_seconds: float = 60.0,
        startup_timeout_seconds: float = 180.0,
    ):
        """Spawn the worker and initialize its CodeExecutor session.

        Args:
            output_dir: Session output directory (also the worker's write root).
            allowed_files: Data files the worker may read.
            python_executable: Interpreter for the worker; defaults to this one.
            request_timeout_seconds: Per-request response timeout.
            startup_timeout_seconds: Timeout for the (slow) init_session request.
        """
        self.output_dir = os.path.abspath(output_dir)
        self.allowed_files = [os.path.abspath(path) for path in (allowed_files or [])]
        # Directories containing the allowed files; the worker may read
        # anything under these roots.
        self.allowed_read_roots = sorted(
            {os.path.dirname(path) for path in self.allowed_files}
        )
        self.python_executable = python_executable or sys.executable
        self.request_timeout_seconds = request_timeout_seconds
        self.startup_timeout_seconds = startup_timeout_seconds
        self._process: Optional[subprocess.Popen] = None
        self._stderr_handle = None
        self._start_worker()
        # The first request builds the CodeExecutor inside the worker; it uses
        # the longer startup timeout because heavy imports happen there.
        self._request(
            "init_session",
            {
                "output_dir": self.output_dir,
                "variables": {
                    "session_output_dir": self.output_dir,
                    "allowed_files": self.allowed_files,
                    "allowed_read_roots": self.allowed_read_roots,
                },
            },
            timeout_seconds=self.startup_timeout_seconds,
        )

    def execute_code(self, code: str) -> Dict[str, Any]:
        """Execute code in the worker; returns the CodeExecutor result dict."""
        return self._request("execute_code", {"code": code})

    def set_variable(self, name: str, value: Any) -> None:
        """Set a variable in the worker's namespace (value must survive JSON)."""
        self._request("set_variable", {"name": name, "value": value})

    def get_environment_info(self) -> str:
        """Return the worker's environment summary string."""
        payload = self._request("get_environment_info", {})
        return payload.get("environment_info", "")

    def reset_environment(self) -> None:
        """Reset the worker environment, then restore the session variables."""
        self._request("reset_environment", {})
        self.set_variable("session_output_dir", self.output_dir)
        self.set_variable("allowed_files", self.allowed_files)
        self.set_variable("allowed_read_roots", self.allowed_read_roots)

    def ping(self) -> bool:
        """Return True if the worker process responds to a ping request."""
        payload = self._request("ping", {})
        return bool(payload.get("alive"))

    def close(self) -> None:
        """Request a graceful shutdown, then tear the worker down regardless."""
        if self._process is None:
            return
        try:
            if self._process.poll() is None:
                self._request("shutdown", {}, timeout_seconds=5)
        except Exception:
            pass  # best-effort: teardown below terminates the process anyway
        finally:
            self._teardown_worker()

    def _start_worker(self) -> None:
        """Spawn the execution_worker.py subprocess with an isolated runtime."""
        # Private matplotlib/IPython config dirs so parallel sessions do not
        # race on shared caches.
        runtime_dir = os.path.join(self.output_dir, ".worker_runtime")
        mpl_dir = os.path.join(runtime_dir, "mplconfig")
        ipython_dir = os.path.join(runtime_dir, "ipython")
        os.makedirs(mpl_dir, exist_ok=True)
        os.makedirs(ipython_dir, exist_ok=True)
        # Worker stderr (logs, captured user prints) goes to this file.
        stderr_log_path = os.path.join(self.output_dir, "execution_worker.log")
        self._stderr_handle = open(stderr_log_path, "a", encoding="utf-8")
        worker_script = os.path.join(
            os.path.dirname(__file__),
            "execution_worker.py",
        )
        env = os.environ.copy()
        env["MPLCONFIGDIR"] = mpl_dir
        env["IPYTHONDIR"] = ipython_dir
        env.setdefault("PYTHONIOENCODING", "utf-8")
        self._process = subprocess.Popen(
            [self.python_executable, worker_script],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=self._stderr_handle,
            text=True,
            encoding="utf-8",
            # Run from the project root so `utils.*` imports resolve.
            cwd=os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
            env=env,
            bufsize=1,  # line-buffered: each JSON line is flushed promptly
        )

    def _request(
        self,
        request_type: str,
        payload: Dict[str, Any],
        timeout_seconds: float = None,
    ) -> Dict[str, Any]:
        """Send one JSON-line request and return the response payload.

        Raises:
            WorkerSessionError: process not running, communication failure,
                invalid or mismatched response, or worker-side error status.
            WorkerTimeoutError: no response within the effective timeout.
        """
        if self._process is None or self._process.stdin is None or self._process.stdout is None:
            raise WorkerSessionError("执行子进程未启动")
        if self._process.poll() is not None:
            raise WorkerSessionError(
                f"执行子进程已退出,退出码: {self._process.returncode}"
            )
        # Correlate request and response via a fresh UUID.
        request_id = str(uuid.uuid4())
        message = {
            "request_id": request_id,
            "type": request_type,
            "payload": payload,
        }
        try:
            self._process.stdin.write(json.dumps(message, ensure_ascii=False) + "\n")
            self._process.stdin.flush()
        except BrokenPipeError as exc:
            self._teardown_worker()
            raise WorkerSessionError("执行子进程通信中断") from exc
        effective_timeout = (
            self.request_timeout_seconds
            if timeout_seconds is None
            else timeout_seconds
        )
        response_line = self._read_response_line(effective_timeout)
        if not response_line:
            # Empty line means EOF: distinguish a crashed worker from a hang.
            if self._process.poll() is not None:
                exit_code = self._process.returncode
                self._teardown_worker()
                raise WorkerSessionError(f"执行子进程已异常退出,退出码: {exit_code}")
            raise WorkerSessionError("执行子进程未返回响应")
        try:
            response = json.loads(response_line)
        except json.JSONDecodeError as exc:
            raise WorkerSessionError(f"执行子进程返回了无效JSON: {response_line}") from exc
        if response.get("request_id") != request_id:
            raise WorkerSessionError("执行子进程响应 request_id 不匹配")
        if response.get("status") != "ok":
            raise WorkerSessionError(response.get("error", "执行子进程返回未知错误"))
        return response.get("payload", {})

    def _read_response_line(self, timeout_seconds: float) -> str:
        """Read one stdout line from the worker, enforcing a timeout.

        readline() blocks with no native timeout, so it runs in a daemon
        thread and the result is fetched through a queue with a deadline.
        On timeout the worker is force-killed (its state is unrecoverable).
        """
        assert self._process is not None and self._process.stdout is not None
        response_queue: queue.Queue = queue.Queue(maxsize=1)
        def _reader() -> None:
            try:
                response_queue.put((True, self._process.stdout.readline()))
            except Exception as exc:
                response_queue.put((False, exc))
        thread = threading.Thread(target=_reader, daemon=True)
        thread.start()
        try:
            success, value = response_queue.get(timeout=timeout_seconds)
        except queue.Empty as exc:
            self._teardown_worker(force=True)
            raise WorkerTimeoutError(
                f"执行子进程在 {timeout_seconds:.1f} 秒内未响应,已终止当前会话"
            ) from exc
        if success:
            return value
        self._teardown_worker()
        raise WorkerSessionError(f"读取执行子进程响应失败: {value}")

    def _teardown_worker(self, force: bool = False) -> None:
        """Terminate the worker (escalating to kill) and release the log handle."""
        if self._process is not None and self._process.poll() is None:
            self._process.terminate()
            try:
                self._process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # terminate() was ignored; escalate.
                self._process.kill()
                self._process.wait(timeout=5)
        if self._stderr_handle is not None:
            self._stderr_handle.close()
            self._stderr_handle = None
        self._process = None

    def __del__(self):
        # Best-effort cleanup if the caller forgot to close().
        self.close()

321
utils/execution_worker.py Normal file
View File

@@ -0,0 +1,321 @@
# -*- coding: utf-8 -*-
"""
Subprocess worker that hosts a single CodeExecutor instance for one analysis session.
"""
import json
import os
import sys
import traceback
from pathlib import Path
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from typing import Any, Dict, Iterable, Optional
# Make the project root importable so `utils.*` resolves when this script is
# launched directly as a subprocess (see ExecutionSessionClient._start_worker).
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
class WorkerProtocolError(RuntimeError):
    """Raised when the worker receives an invalid protocol message: unknown
    request type, missing required fields, or a request made before
    init_session."""
class FileAccessPolicy:
    """Controls which files the worker may read and where it may write outputs.

    Read access is granted for explicitly allowed files, anything under the
    read roots (explicit or derived from the allowed files' directories), and
    anything under the write root. Write access is limited to the write root.
    All paths are compared in normalized (absolute, symlink-resolved) form.
    """

    def __init__(self):
        self.allowed_reads = set()
        self.allowed_read_roots = set()
        self.allowed_write_root = ""

    @staticmethod
    def _normalize(path: Any) -> str:
        """Coerce a path-like value to a canonical absolute string path."""
        if isinstance(path, Path):
            candidate = str(path)
        elif hasattr(path, "__fspath__"):
            candidate = os.fspath(path)
        elif isinstance(path, str):
            candidate = path
        else:
            raise TypeError(f"不支持的路径类型: {type(path).__name__}")
        return os.path.realpath(os.path.abspath(candidate))

    def configure(
        self,
        allowed_reads: Iterable[Any],
        allowed_write_root: Any,
        allowed_read_roots: Optional[Iterable[Any]] = None,
    ) -> None:
        """Replace the policy: readable files, write root, and read roots."""
        normalize = self._normalize
        self.allowed_reads = {normalize(item) for item in allowed_reads if item}
        # Read roots = explicitly supplied roots plus the parent directory of
        # every individually allowed file.
        roots = {normalize(item) for item in (allowed_read_roots or []) if item}
        roots.update(os.path.dirname(item) for item in self.allowed_reads)
        self.allowed_read_roots = roots
        self.allowed_write_root = normalize(allowed_write_root) if allowed_write_root else ""

    def ensure_readable(self, path: Any) -> str:
        """Return the normalized path if readable; raise PermissionError otherwise."""
        resolved = self._normalize(path)
        readable = (
            resolved in self.allowed_reads
            or self._is_within_read_roots(resolved)
            or self._is_within_write_root(resolved)
        )
        if not readable:
            raise PermissionError(f"禁止读取未授权文件: {resolved}")
        return resolved

    def ensure_writable(self, path: Any) -> str:
        """Return the normalized path if writable; raise PermissionError otherwise."""
        resolved = self._normalize(path)
        if not self._is_within_write_root(resolved):
            raise PermissionError(f"禁止写入会话目录之外的路径: {resolved}")
        return resolved

    def _is_within_write_root(self, normalized_path: str) -> bool:
        """True when the path equals or lives under the configured write root."""
        root = self.allowed_write_root
        if not root:
            return False
        return normalized_path == root or normalized_path.startswith(root + os.sep)

    def _is_within_read_roots(self, normalized_path: str) -> bool:
        """True when the path equals or lives under any configured read root."""
        return any(
            normalized_path == root or normalized_path.startswith(root + os.sep)
            for root in self.allowed_read_roots
        )
def _write_message(message: Dict[str, Any]) -> None:
    """Emit one protocol message as a single JSON line on the real stdout
    and flush immediately (the parent reads line-by-line)."""
    sys.stdout.write(json.dumps(message, ensure_ascii=False) + "\n")
    sys.stdout.flush()
def _write_log(text: str) -> None:
    """Forward captured output to stderr (collected by the parent process),
    guaranteeing a trailing newline; empty text is dropped."""
    if not text:
        return
    sys.stderr.write(text)
    if not text.endswith("\n"):
        sys.stderr.write("\n")
    sys.stderr.flush()
class ExecutionWorker:
    """JSON-line protocol wrapper around CodeExecutor.

    Holds one CodeExecutor (created by the ``init_session`` request) and a
    FileAccessPolicy that is enforced by monkey-patching common file APIs.
    """

    def __init__(self):
        # Created lazily by init_session (importing CodeExecutor is heavy).
        self.executor = None
        self.access_policy = FileAccessPolicy()
        self._patches_installed = False

    def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch one protocol request.

        Never raises: any exception is converted into an ``error`` response
        carrying the message and traceback.
        """
        request_id = request.get("request_id", "")
        request_type = request.get("type")
        payload = request.get("payload", {})
        try:
            if request_type == "ping":
                return self._ok(request_id, {"alive": True})
            if request_type == "init_session":
                return self._handle_init_session(request_id, payload)
            if request_type == "execute_code":
                self._require_executor()
                return self._ok(
                    request_id,
                    self.executor.execute_code(payload.get("code", "")),
                )
            if request_type == "set_variable":
                self._require_executor()
                self._handle_set_variable(payload["name"], payload["value"])
                return self._ok(request_id, {"set": True})
            if request_type == "get_environment_info":
                self._require_executor()
                return self._ok(
                    request_id,
                    {"environment_info": self.executor.get_environment_info()},
                )
            if request_type == "reset_environment":
                self._require_executor()
                self.executor.reset_environment()
                return self._ok(request_id, {"reset": True})
            if request_type == "shutdown":
                # The main loop exits after acknowledging a shutdown.
                return self._ok(request_id, {"shutdown": True})
            raise WorkerProtocolError(f"未知请求类型: {request_type}")
        except Exception as exc:
            return {
                "request_id": request_id,
                "status": "error",
                "error": str(exc),
                "traceback": traceback.format_exc(),
            }

    def _handle_init_session(
        self, request_id: str, payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Create the CodeExecutor, configure file access, seed variables."""
        output_dir = payload.get("output_dir")
        if not output_dir:
            raise WorkerProtocolError("init_session 缺少 output_dir")
        # Imported here (not at module top) so the worker starts fast and
        # heavy imports only happen once a session is actually requested.
        from utils.code_executor import CodeExecutor
        self.executor = CodeExecutor(output_dir)
        self.access_policy.configure(
            payload.get("variables", {}).get("allowed_files", []),
            output_dir,
            payload.get("variables", {}).get("allowed_read_roots", []),
        )
        self._install_file_guards()
        for name, value in payload.get("variables", {}).items():
            self.executor.set_variable(name, value)
        return self._ok(request_id, {"initialized": True})

    def _install_file_guards(self) -> None:
        """Monkey-patch builtins.open, pandas I/O, and matplotlib savefig so
        every path-taking call is checked against the access policy.

        Installed at most once per process; patches are process-global, which
        is acceptable because this worker hosts exactly one session.
        """
        if self._patches_installed:
            return
        import builtins
        import matplotlib.figure
        import matplotlib.pyplot as plt
        import pandas as pd
        policy = self.access_policy
        # Keep originals so the guards can delegate after the check.
        original_open = builtins.open
        original_read_csv = pd.read_csv
        original_read_excel = pd.read_excel
        original_to_csv = pd.DataFrame.to_csv
        original_to_excel = pd.DataFrame.to_excel
        original_plt_savefig = plt.savefig
        original_figure_savefig = matplotlib.figure.Figure.savefig
        def guarded_open(file, mode="r", *args, **kwargs):
            # Only path-like targets are checked; file objects/fds pass through.
            if isinstance(file, (str, Path)) or hasattr(file, "__fspath__"):
                if any(flag in mode for flag in ("w", "a", "x", "+")):
                    policy.ensure_writable(file)
                else:
                    policy.ensure_readable(file)
            return original_open(file, mode, *args, **kwargs)
        def guarded_read_csv(filepath_or_buffer, *args, **kwargs):
            if isinstance(filepath_or_buffer, (str, Path)) or hasattr(
                filepath_or_buffer, "__fspath__"
            ):
                policy.ensure_readable(filepath_or_buffer)
            return original_read_csv(filepath_or_buffer, *args, **kwargs)
        def guarded_read_excel(io, *args, **kwargs):
            if isinstance(io, (str, Path)) or hasattr(io, "__fspath__"):
                policy.ensure_readable(io)
            return original_read_excel(io, *args, **kwargs)
        def guarded_to_csv(df, path_or_buf=None, *args, **kwargs):
            if isinstance(path_or_buf, (str, Path)) or hasattr(path_or_buf, "__fspath__"):
                policy.ensure_writable(path_or_buf)
            return original_to_csv(df, path_or_buf, *args, **kwargs)
        def guarded_to_excel(df, excel_writer, *args, **kwargs):
            if isinstance(excel_writer, (str, Path)) or hasattr(excel_writer, "__fspath__"):
                policy.ensure_writable(excel_writer)
            return original_to_excel(df, excel_writer, *args, **kwargs)
        def guarded_savefig(*args, **kwargs):
            # plt.savefig takes the target positionally or as fname=.
            target = args[0] if args else kwargs.get("fname")
            if target is not None and (
                isinstance(target, (str, Path)) or hasattr(target, "__fspath__")
            ):
                policy.ensure_writable(target)
            return original_plt_savefig(*args, **kwargs)
        def guarded_figure_savefig(fig, fname, *args, **kwargs):
            if isinstance(fname, (str, Path)) or hasattr(fname, "__fspath__"):
                policy.ensure_writable(fname)
            return original_figure_savefig(fig, fname, *args, **kwargs)
        builtins.open = guarded_open
        pd.read_csv = guarded_read_csv
        pd.read_excel = guarded_read_excel
        pd.DataFrame.to_csv = guarded_to_csv
        pd.DataFrame.to_excel = guarded_to_excel
        plt.savefig = guarded_savefig
        matplotlib.figure.Figure.savefig = guarded_figure_savefig
        self._patches_installed = True

    def _require_executor(self) -> None:
        """Raise unless init_session has already created the executor."""
        if self.executor is None:
            raise WorkerProtocolError("执行会话尚未初始化")

    def _handle_set_variable(self, name: str, value: Any) -> None:
        """Set a namespace variable; the three policy-related variables also
        reconfigure the file access policy to stay in sync."""
        self.executor.set_variable(name, value)
        if name == "allowed_files":
            self.access_policy.configure(
                value,
                self.access_policy.allowed_write_root,
                self.access_policy.allowed_read_roots,
            )
        elif name == "allowed_read_roots":
            self.access_policy.configure(
                self.access_policy.allowed_reads,
                self.access_policy.allowed_write_root,
                value,
            )
        elif name == "session_output_dir":
            self.access_policy.configure(
                self.access_policy.allowed_reads,
                value,
                self.access_policy.allowed_read_roots,
            )

    @staticmethod
    def _ok(request_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Build a success response envelope."""
        return {
            "request_id": request_id,
            "status": "ok",
            "payload": payload,
        }
def main() -> int:
    """Serve JSON-line requests from stdin until shutdown or EOF.

    User code's stdout/stderr is captured during handling and forwarded to
    stderr, keeping stdout reserved exclusively for protocol messages.
    """
    worker = ExecutionWorker()
    for raw_line in sys.stdin:
        raw_line = raw_line.strip()
        if not raw_line:
            continue  # tolerate blank lines
        try:
            request = json.loads(raw_line)
        except json.JSONDecodeError as exc:
            # Unparseable request: report and keep serving.
            _write_message(
                {
                    "request_id": "",
                    "status": "error",
                    "error": f"无效JSON请求: {exc}",
                }
            )
            continue
        # Redirect any prints produced while handling so they cannot corrupt
        # the stdout protocol stream.
        captured_stdout = StringIO()
        captured_stderr = StringIO()
        with redirect_stdout(captured_stdout), redirect_stderr(captured_stderr):
            response = worker.handle_request(request)
        _write_log(captured_stdout.getvalue())
        _write_log(captured_stderr.getvalue())
        _write_message(response)
        # A successfully acknowledged shutdown ends the serve loop.
        if request.get("type") == "shutdown" and response.get("status") == "ok":
            return 0
    return 0
# Script entry point: run the serve loop and propagate its exit code.
if __name__ == "__main__":
    raise SystemExit(main())

38
utils/extract_code.py Normal file
View File

@@ -0,0 +1,38 @@
from typing import Optional
import yaml
def extract_code_from_response(response: str) -> Optional[str]:
    """Extract executable code from an LLM response.

    Tries, in order:
      1. A ```yaml fenced block (or bare YAML) containing a top-level
         ``code`` key.
      2. A ```python fenced block.
      3. Any generic ``` fenced block.

    Returns:
        The extracted code string, or None when nothing code-like is found.
    """
    try:
        # Prefer an explicit YAML payload.
        if '```yaml' in response:
            start = response.find('```yaml') + 7
            end = response.find('```', start)
            # A missing closing fence made end == -1, so [start:end] silently
            # dropped the last character; take the rest of the string instead.
            yaml_content = (response[start:end] if end != -1 else response[start:]).strip()
        elif '```' in response:
            start = response.find('```') + 3
            end = response.find('```', start)
            yaml_content = (response[start:end] if end != -1 else response[start:]).strip()
        else:
            yaml_content = response.strip()
        yaml_data = yaml.safe_load(yaml_content)
        # Only a mapping can carry a 'code' entry. The old `'code' in yaml_data`
        # check did a substring test when YAML parsed to a plain string.
        if isinstance(yaml_data, dict) and 'code' in yaml_data:
            return yaml_data['code']
    except Exception:
        # Malformed YAML: fall through to fenced-code extraction.
        pass
    # Fallback: look for a ```python block, then any ``` block.
    if '```python' in response:
        start = response.find('```python') + 9
        end = response.find('```', start)
        if end != -1:
            return response[start:end].strip()
    elif '```' in response:
        start = response.find('```') + 3
        end = response.find('```', start)
        if end != -1:
            return response[start:end].strip()
    return None

View File

@@ -0,0 +1,230 @@
# -*- coding: utf-8 -*-
import asyncio
from typing import Optional, Any, Mapping, Dict
from openai import AsyncOpenAI, APIStatusError, APIConnectionError, APITimeoutError, APIError
from openai.types.chat import ChatCompletion
class AsyncFallbackOpenAIClient:
"""
一个支持备用 API 自动切换的异步 OpenAI 客户端。
当主 API 调用因特定错误(如内容过滤)失败时,会自动尝试使用备用 API。
"""
def __init__(
self,
primary_api_key: str,
primary_base_url: str,
primary_model_name: str,
fallback_api_key: Optional[str] = None,
fallback_base_url: Optional[str] = None,
fallback_model_name: Optional[str] = None,
primary_client_args: Optional[Dict[str, Any]] = None,
fallback_client_args: Optional[Dict[str, Any]] = None,
content_filter_error_code: str = "1301", # 特定于 Zhipu 的内容过滤错误代码
content_filter_error_field: str = "contentFilter", # 特定于 Zhipu 的内容过滤错误字段
max_retries_primary: int = 1, # 主API重试次数
max_retries_fallback: int = 1, # 备用API重试次数
retry_delay_seconds: float = 1.0 # 重试延迟时间
):
"""
初始化 AsyncFallbackOpenAIClient。
Args:
primary_api_key: 主 API 的密钥。
primary_base_url: 主 API 的基础 URL。
primary_model_name: 主 API 使用的模型名称。
fallback_api_key: 备用 API 的密钥 (可选)。
fallback_base_url: 备用 API 的基础 URL (可选)。
fallback_model_name: 备用 API 使用的模型名称 (可选)。
primary_client_args: 传递给主 AsyncOpenAI 客户端的其他参数。
fallback_client_args: 传递给备用 AsyncOpenAI 客户端的其他参数。
content_filter_error_code: 触发回退的内容过滤错误的特定错误代码。
content_filter_error_field: 触发回退的内容过滤错误中存在的字段名。
max_retries_primary: 主 API 失败时的最大重试次数。
max_retries_fallback: 备用 API 失败时的最大重试次数。
retry_delay_seconds: 重试前的延迟时间(秒)。
"""
if not primary_api_key or not primary_base_url:
raise ValueError("主 API 密钥和基础 URL 不能为空。")
_primary_args = primary_client_args or {}
self.primary_client = AsyncOpenAI(api_key=primary_api_key, base_url=primary_base_url, **_primary_args)
self.primary_model_name = primary_model_name
self.fallback_client: Optional[AsyncOpenAI] = None
self.fallback_model_name: Optional[str] = None
if fallback_api_key and fallback_base_url and fallback_model_name:
_fallback_args = fallback_client_args or {}
self.fallback_client = AsyncOpenAI(api_key=fallback_api_key, base_url=fallback_base_url, **_fallback_args)
self.fallback_model_name = fallback_model_name
else:
print("⚠️ 警告: 未完全配置备用 API 客户端。如果主 API 失败,将无法进行回退。")
self.content_filter_error_code = content_filter_error_code
self.content_filter_error_field = content_filter_error_field
self.max_retries_primary = max_retries_primary
self.max_retries_fallback = max_retries_fallback
self.retry_delay_seconds = retry_delay_seconds
self._closed = False
async def _attempt_api_call(
self,
client: AsyncOpenAI,
model_name: str,
messages: list[Mapping[str, Any]],
max_retries: int,
api_name: str,
**kwargs: Any
) -> ChatCompletion:
"""
尝试调用指定的 OpenAI API 客户端,并进行重试。
"""
last_exception = None
for attempt in range(max_retries + 1):
try:
# print(f"尝试使用 {api_name} API ({client.base_url}) 模型: {kwargs.get('model', model_name)}, 第 {attempt + 1} 次尝试")
completion = await client.chat.completions.create(
model=kwargs.pop('model', model_name),
messages=messages,
**kwargs
)
return completion
except (APIConnectionError, APITimeoutError) as e: # 通常可以重试的网络错误
last_exception = e
print(f"⚠️ {api_name} API 调用时发生可重试错误 ({type(e).__name__}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
if attempt < max_retries:
await asyncio.sleep(self.retry_delay_seconds * (attempt + 1)) # 增加延迟
else:
print(f"{api_name} API 在达到最大重试次数后仍然失败。")
except APIStatusError as e: # API 返回的特定状态码错误
is_content_filter_error = False
if e.status_code == 400:
try:
error_json = e.response.json()
error_details = error_json.get("error", {})
if (error_details.get("code") == self.content_filter_error_code and
self.content_filter_error_field in error_json):
is_content_filter_error = True
except Exception:
pass # 解析错误响应失败,不认为是内容过滤错误
if is_content_filter_error and api_name == "": # 如果是主 API 的内容过滤错误,则直接抛出以便回退
raise e
last_exception = e
print(f"⚠️ {api_name} API 调用时发生 APIStatusError ({e.status_code}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
if attempt < max_retries:
await asyncio.sleep(self.retry_delay_seconds * (attempt + 1))
else:
print(f"{api_name} API 在达到最大重试次数后仍然失败 (APIStatusError)。")
except APIError as e: # 其他不可轻易重试的 OpenAI 错误
last_exception = e
print(f"{api_name} API 调用时发生不可重试错误 ({type(e).__name__}): {e}")
break # 不再重试此类错误
if last_exception:
raise last_exception
raise RuntimeError(f"{api_name} API 调用意外失败。") # 理论上不应到达这里
async def chat_completions_create(
    self,
    messages: list[Mapping[str, Any]],
    **kwargs: Any  # extra OpenAI parameters, e.g. max_tokens, temperature
) -> ChatCompletion:
    """
    Create a chat completion via the primary API, falling back to the backup
    API when the primary hits a content-filter error or fails outright.

    Retryable errors are retried on both APIs (see _attempt_api_call).

    Args:
        messages: OpenAI chat message list.
        **kwargs: additional parameters forwarded to the API call.

    Returns:
        ChatCompletion object.

    Raises:
        APIError: if the primary and (when attempted) fallback APIs both fail.
        RuntimeError: if the client has been closed.
    """
    if self._closed:
        raise RuntimeError("客户端已关闭。")
    try:
        return await self._attempt_api_call(
            client=self.primary_client,
            model_name=self.primary_model_name,
            messages=messages,
            max_retries=self.max_retries_primary,
            api_name="",
            **kwargs.copy()
        )
    except APIStatusError as e_primary:
        # A primary content-filter rejection triggers the fallback API, when configured.
        if (self._is_content_filter_error(e_primary)
                and self.fallback_client and self.fallback_model_name):
            print(f" 主 API 内容过滤错误 ({e_primary.status_code})。尝试切换到备用 API ({self.fallback_client.base_url})...")
            return await self._fallback_call(
                messages, "❌ 备用 API 调用最终失败", **kwargs)
        # Not a content-filter error, or no fallback configured: report and re-raise.
        print(f" 主 API 错误 ({type(e_primary).__name__}: {e_primary}), 且不满足备用条件或备用API未配置。")
        raise e_primary
    except APIError as e_primary_other:
        # Any other primary failure (non content-filter) also tries the fallback.
        print(f"❌ 主 API 调用最终失败 (非内容过滤,错误类型: {type(e_primary_other).__name__}): {e_primary_other}")
        if self.fallback_client and self.fallback_model_name:
            print(f" 主 API 失败,尝试切换到备用 API ({self.fallback_client.base_url})...")
            return await self._fallback_call(
                messages, "❌ 备用 API 在主 API 失败后也调用失败", **kwargs)
        raise e_primary_other

def _is_content_filter_error(self, error: APIStatusError) -> bool:
    """Return True when ``error`` is the provider's content-filter rejection.

    Detection: HTTP 400 whose JSON body carries the configured
    content_filter_error_code and content_filter_error_field.
    """
    if error.status_code != 400:
        return False
    try:
        error_json = error.response.json()
        error_details = error_json.get("error", {})
        return (error_details.get("code") == self.content_filter_error_code and
                self.content_filter_error_field in error_json)
    except Exception:
        return False  # error body unparseable: not a content-filter error

async def _fallback_call(
    self,
    messages: list[Mapping[str, Any]],
    failure_label: str,
    **kwargs: Any
) -> ChatCompletion:
    """Run the request against the fallback API, logging success or failure.

    ``failure_label`` is the log prefix used when the fallback also fails,
    so both call sites keep their original messages.
    """
    try:
        fallback_completion = await self._attempt_api_call(
            client=self.fallback_client,
            model_name=self.fallback_model_name,
            messages=messages,
            max_retries=self.max_retries_fallback,
            api_name="备用",
            **kwargs.copy()
        )
        print(f"✅ 备用 API 调用成功。")
        return fallback_completion
    except APIError as e_fallback:
        print(f"{failure_label}: {type(e_fallback).__name__} - {e_fallback}")
        raise e_fallback
async def close(self):
    """Asynchronously close the primary client and, when configured, the fallback client.

    Idempotent: once the client is marked closed, further calls are no-ops.
    """
    if self._closed:
        return
    await self.primary_client.close()
    if self.fallback_client:
        await self.fallback_client.close()
    self._closed = True
async def __aenter__(self):
    """Enter the async context; a client that has been closed cannot be reused."""
    if not self._closed:
        return self
    raise RuntimeError("AsyncFallbackOpenAIClient 不能在关闭后重新进入。请创建一个新实例。")
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Exit the async context by closing both underlying clients via close()."""
    await self.close()

View File

@@ -0,0 +1,25 @@
from typing import Any, Dict
def format_execution_result(result: Dict[str, Any]) -> str:
    """Render a code-execution result dict as human-readable feedback.

    Expected keys: 'success' (bool), 'output' (str), and on success an
    optional 'variables' mapping; on failure an 'error' message.

    Robustness fix: all optional fields are read with dict.get() so a
    partially-populated result dict no longer raises KeyError (the original
    used bracket access for 'error' and 'output').

    Args:
        result: execution-result dictionary.

    Returns:
        Multi-line feedback string joined with newlines.
    """
    feedback = []
    if result.get('success'):
        feedback.append("✅ 代码执行成功")
        if result.get('output'):
            feedback.append(f"📊 输出结果:\n{result['output']}")
        if result.get('variables'):
            feedback.append("📋 新生成的变量:")
            for var_name, var_info in result['variables'].items():
                feedback.append(f" - {var_name}: {var_info}")
    else:
        feedback.append("❌ 代码执行失败")
        feedback.append(f"错误信息: {result.get('error')}")
        if result.get('output'):
            feedback.append(f"部分输出: {result['output']}")
    return "\n".join(feedback)

91
utils/llm_helper.py Normal file
View File

@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
"""
LLM调用辅助模块
"""
import asyncio
import yaml
from config.llm_config import LLMConfig
from utils.fallback_openai_client import AsyncFallbackOpenAIClient
class LLMCallError(RuntimeError):
    """Raised when the configured LLM backend cannot complete a request.

    The underlying client exception is attached as ``__cause__`` (it is
    raised via ``raise ... from``), so callers can inspect the root failure.
    """
class LLMHelper:
    """LLM call helper supporting both synchronous and asynchronous usage.

    Wraps an AsyncFallbackOpenAIClient built from an LLMConfig instance.
    """

    def __init__(self, config: LLMConfig = None):
        """Validate ``config`` (defaulting to LLMConfig()) and build the client."""
        self.config = config or LLMConfig()
        self.config.validate()
        self.client = AsyncFallbackOpenAIClient(
            primary_api_key=self.config.api_key,
            primary_base_url=self.config.base_url,
            primary_model_name=self.config.model
        )

    async def async_call(self, prompt: str, system_prompt: str = None,
                         max_tokens: int = None, temperature: float = None) -> str:
        """Asynchronously send ``prompt`` to the LLM and return the reply text.

        Args:
            prompt: user message content.
            system_prompt: optional system message prepended to the conversation.
            max_tokens: per-call override; defaults to the configured value.
            temperature: per-call override; defaults to the configured value.

        Raises:
            LLMCallError: wrapping any failure from the underlying client.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        # Explicit per-call values take precedence over configured defaults.
        kwargs = {
            'max_tokens': max_tokens if max_tokens is not None else self.config.max_tokens,
            'temperature': temperature if temperature is not None else self.config.temperature,
        }
        try:
            response = await self.client.chat_completions_create(
                messages=messages,
                **kwargs
            )
            return response.choices[0].message.content
        except Exception as e:
            raise LLMCallError(f"LLM调用失败: {e}") from e

    def call(self, prompt: str, system_prompt: str = None,
             max_tokens: int = None, temperature: float = None) -> str:
        """Synchronous wrapper around async_call.

        BUG FIX: the original applied nest_asyncio only in the no-event-loop
        branch, so calling from a context where a loop already exists and is
        running (e.g. Jupyter) failed inside run_until_complete. nest_asyncio
        is now applied unconditionally, making both paths re-entrant safe.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        import nest_asyncio
        nest_asyncio.apply()
        return loop.run_until_complete(self.async_call(prompt, system_prompt, max_tokens, temperature))

    def parse_yaml_response(self, response: str) -> dict:
        """Extract and parse a YAML payload from an LLM response string.

        Strips a ```yaml ... ``` (or plain ``` ... ```) code fence when present.

        Returns:
            The parsed object, or {} when parsing fails.
        """
        try:
            yaml_content = self._strip_code_fence(response)
            return yaml.safe_load(yaml_content)
        except Exception as e:
            print(f"YAML解析失败: {e}")
            print(f"原始响应: {response}")
            return {}

    @staticmethod
    def _strip_code_fence(response: str) -> str:
        """Return the content inside the first code fence, or the stripped text."""
        if '```yaml' in response:
            start = response.find('```yaml') + 7
        elif '```' in response:
            start = response.find('```') + 3
        else:
            return response.strip()
        end = response.find('```', start)
        # BUG FIX: a missing closing fence made find() return -1, and the
        # original slice response[start:-1] silently dropped the last character.
        if end == -1:
            end = len(response)
        return response[start:end].strip()

    async def close(self):
        """Close the underlying fallback-capable client."""
        await self.client.close()