feat: adjust report format and enforce image persistence
This commit is contained in:
10
utils/__init__.py
Normal file
10
utils/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
工具模块初始化文件
|
||||
"""
|
||||
|
||||
from utils.code_executor import CodeExecutor
|
||||
from utils.llm_helper import LLMHelper
|
||||
from utils.fallback_openai_client import AsyncFallbackOpenAIClient
|
||||
|
||||
__all__ = ["CodeExecutor", "LLMHelper", "AsyncFallbackOpenAIClient"]
|
||||
459
utils/code_executor.py
Normal file
459
utils/code_executor.py
Normal file
@@ -0,0 +1,459 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
安全的代码执行器,基于 IPython 提供 notebook 环境下的代码执行功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import ast
|
||||
import traceback
|
||||
import io
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
from IPython.core.interactiveshell import InteractiveShell
|
||||
from IPython.utils.capture import capture_output
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.font_manager as fm
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
"""
|
||||
安全的代码执行器,限制依赖库,捕获输出,支持图片保存与路径输出
|
||||
"""
|
||||
|
||||
ALLOWED_IMPORTS = {
|
||||
"pandas",
|
||||
"pd",
|
||||
"numpy",
|
||||
"np",
|
||||
"matplotlib",
|
||||
"matplotlib.pyplot",
|
||||
"plt",
|
||||
"seaborn",
|
||||
"sns",
|
||||
"duckdb",
|
||||
"scipy",
|
||||
"sklearn",
|
||||
"statsmodels",
|
||||
"plotly",
|
||||
"dash",
|
||||
"requests",
|
||||
"urllib",
|
||||
"os",
|
||||
"sys",
|
||||
"json",
|
||||
"csv",
|
||||
"datetime",
|
||||
"time",
|
||||
"math",
|
||||
"statistics",
|
||||
"re",
|
||||
"pathlib",
|
||||
"io",
|
||||
"collections",
|
||||
"itertools",
|
||||
"functools",
|
||||
"operator",
|
||||
"warnings",
|
||||
"logging",
|
||||
"copy",
|
||||
"pickle",
|
||||
"gzip",
|
||||
"zipfile",
|
||||
"yaml",
|
||||
"typing",
|
||||
"dataclasses",
|
||||
"enum",
|
||||
"sqlite3",
|
||||
"jieba",
|
||||
"wordcloud",
|
||||
"PIL",
|
||||
"random",
|
||||
"networkx",
|
||||
}
|
||||
|
||||
def __init__(self, output_dir: str = "outputs"):
|
||||
"""
|
||||
初始化代码执行器
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录,用于保存图片和文件
|
||||
"""
|
||||
self.output_dir = os.path.abspath(output_dir)
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
|
||||
# 初始化 IPython shell
|
||||
self.shell = InteractiveShell.instance()
|
||||
|
||||
# 设置中文字体
|
||||
self._setup_chinese_font()
|
||||
|
||||
# 预导入常用库
|
||||
self._setup_common_imports()
|
||||
|
||||
# 图片计数器
|
||||
self.image_counter = 0
|
||||
|
||||
def _setup_chinese_font(self):
|
||||
"""设置matplotlib中文字体显示"""
|
||||
try:
|
||||
# 设置matplotlib使用Agg backend避免GUI问题
|
||||
matplotlib.use("Agg")
|
||||
|
||||
# 获取系统可用字体
|
||||
available_fonts = [f.name for f in fm.fontManager.ttflist]
|
||||
|
||||
# 设置matplotlib使用系统可用中文字体
|
||||
# macOS系统常用中文字体(按优先级排序)
|
||||
chinese_fonts = [
|
||||
"Hiragino Sans GB", # macOS中文简体
|
||||
"Songti SC", # macOS宋体简体
|
||||
"PingFang SC", # macOS苹方简体
|
||||
"Heiti SC", # macOS黑体简体
|
||||
"Heiti TC", # macOS黑体繁体
|
||||
"PingFang HK", # macOS苹方香港
|
||||
"SimHei", # Windows黑体
|
||||
"STHeiti", # 华文黑体
|
||||
"WenQuanYi Micro Hei", # Linux文泉驿微米黑
|
||||
"DejaVu Sans", # 默认无衬线字体
|
||||
"Arial Unicode MS", # Arial Unicode
|
||||
]
|
||||
|
||||
# 检查系统中实际存在的字体
|
||||
system_chinese_fonts = [
|
||||
font for font in chinese_fonts if font in available_fonts
|
||||
]
|
||||
|
||||
# 如果没有找到合适的中文字体,尝试更宽松的搜索
|
||||
if not system_chinese_fonts:
|
||||
print("警告:未找到精确匹配的中文字体,尝试更宽松的搜索...")
|
||||
# 更宽松的字体匹配(包含部分名称)
|
||||
fallback_fonts = []
|
||||
for available_font in available_fonts:
|
||||
if any(
|
||||
keyword in available_font
|
||||
for keyword in [
|
||||
"Hei",
|
||||
"Song",
|
||||
"Fang",
|
||||
"Kai",
|
||||
"Hiragino",
|
||||
"PingFang",
|
||||
"ST",
|
||||
]
|
||||
):
|
||||
fallback_fonts.append(available_font)
|
||||
|
||||
if fallback_fonts:
|
||||
system_chinese_fonts = fallback_fonts[:3] # 取前3个匹配的字体
|
||||
print(f"找到备选中文字体: {system_chinese_fonts}")
|
||||
else:
|
||||
print("警告:系统中未找到合适的中文字体,使用系统默认字体")
|
||||
system_chinese_fonts = ["DejaVu Sans", "Arial Unicode MS"]
|
||||
|
||||
# 设置字体配置
|
||||
plt.rcParams["font.sans-serif"] = system_chinese_fonts + [
|
||||
"DejaVu Sans",
|
||||
"Arial Unicode MS",
|
||||
]
|
||||
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
plt.rcParams["font.family"] = "sans-serif"
|
||||
|
||||
# 在shell中也设置相同的字体配置
|
||||
font_list_str = str(
|
||||
system_chinese_fonts + ["DejaVu Sans", "Arial Unicode MS"]
|
||||
)
|
||||
self.shell.run_cell(
|
||||
f"""
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.font_manager as fm
|
||||
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = {font_list_str}
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
plt.rcParams['font.family'] = 'sans-serif'
|
||||
|
||||
# 确保matplotlib缓存目录可写
|
||||
import os
|
||||
cache_dir = os.path.expanduser('~/.matplotlib')
|
||||
if not os.path.exists(cache_dir):
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
os.environ['MPLCONFIGDIR'] = cache_dir
|
||||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"设置中文字体失败: {e}")
|
||||
# 即使失败也要设置基本的matplotlib配置
|
||||
try:
|
||||
matplotlib.use("Agg")
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
except:
|
||||
pass
|
||||
|
||||
def _setup_common_imports(self):
|
||||
"""预导入常用库"""
|
||||
common_imports = """
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import duckdb
|
||||
import os
|
||||
import json
|
||||
from IPython.display import display
|
||||
"""
|
||||
try:
|
||||
self.shell.run_cell(common_imports)
|
||||
# 确保display函数在shell的用户命名空间中可用
|
||||
from IPython.display import display
|
||||
|
||||
self.shell.user_ns["display"] = display
|
||||
except Exception as e:
|
||||
print(f"预导入库失败: {e}")
|
||||
|
||||
def _check_code_safety(self, code: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
检查代码安全性,限制导入的库
|
||||
|
||||
Returns:
|
||||
(is_safe, error_message)
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError as e:
|
||||
return False, f"语法错误: {e}"
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
if alias.name not in self.ALLOWED_IMPORTS:
|
||||
return False, f"不允许的导入: {alias.name}"
|
||||
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module not in self.ALLOWED_IMPORTS:
|
||||
return False, f"不允许的导入: {node.module}"
|
||||
|
||||
# 检查属性访问(防止通过os.system等方式绕过)
|
||||
elif isinstance(node, ast.Attribute):
|
||||
# 检查是否访问了os模块的属性
|
||||
if isinstance(node.value, ast.Name) and node.value.id == "os":
|
||||
# 允许的os子模块和函数白名单
|
||||
allowed_os_attributes = {
|
||||
"path", "environ", "getcwd", "listdir", "makedirs", "mkdir", "remove", "rmdir",
|
||||
"path.join", "path.exists", "path.abspath", "path.dirname",
|
||||
"path.basename", "path.splitext", "path.isdir", "path.isfile",
|
||||
"sep", "name", "linesep", "stat", "getpid"
|
||||
}
|
||||
|
||||
# 检查直接属性访问 (如 os.getcwd)
|
||||
if node.attr not in allowed_os_attributes:
|
||||
# 进一步检查如果是 os.path.xxx 这种形式
|
||||
# Note: ast.Attribute 嵌套结构比较复杂,简单处理只允许 os.path 和上述白名单
|
||||
if node.attr == "path":
|
||||
pass # 允许访问 os.path
|
||||
else:
|
||||
return False, f"不允许的os属性访问: os.{node.attr}"
|
||||
|
||||
# 检查危险函数调用
|
||||
elif isinstance(node, ast.Call):
|
||||
if isinstance(node.func, ast.Name):
|
||||
if node.func.id in ["exec", "eval", "open", "__import__"]:
|
||||
return False, f"不允许的函数调用: {node.func.id}"
|
||||
|
||||
return True, ""
|
||||
|
||||
def get_current_figures_info(self) -> List[Dict[str, Any]]:
|
||||
"""获取当前matplotlib图形信息,但不自动保存"""
|
||||
figures_info = []
|
||||
|
||||
# 获取当前所有图形
|
||||
fig_nums = plt.get_fignums()
|
||||
|
||||
for fig_num in fig_nums:
|
||||
fig = plt.figure(fig_num)
|
||||
if fig.get_axes(): # 只处理有内容的图形
|
||||
figures_info.append(
|
||||
{
|
||||
"figure_number": fig_num,
|
||||
"axes_count": len(fig.get_axes()),
|
||||
"figure_size": fig.get_size_inches().tolist(),
|
||||
"has_content": True,
|
||||
}
|
||||
)
|
||||
|
||||
return figures_info
|
||||
|
||||
def _format_table_output(self, obj: Any) -> str:
|
||||
"""格式化表格输出,限制行数"""
|
||||
if hasattr(obj, "shape") and hasattr(obj, "head"): # pandas DataFrame
|
||||
rows, cols = obj.shape
|
||||
print(f"\n数据表形状: {rows}行 x {cols}列")
|
||||
print(f"列名: {list(obj.columns)}")
|
||||
|
||||
if rows <= 15:
|
||||
return str(obj)
|
||||
else:
|
||||
head_part = obj.head(5)
|
||||
tail_part = obj.tail(5)
|
||||
return f"{head_part}\n...\n(省略 {rows-10} 行)\n...\n{tail_part}"
|
||||
|
||||
return str(obj)
|
||||
|
||||
def execute_code(self, code: str) -> Dict[str, Any]:
|
||||
"""
|
||||
执行代码并返回结果
|
||||
|
||||
Args:
|
||||
code: 要执行的Python代码
|
||||
|
||||
Returns:
|
||||
{
|
||||
'success': bool,
|
||||
'output': str,
|
||||
'error': str,
|
||||
'variables': Dict[str, Any] # 新生成的重要变量
|
||||
}
|
||||
"""
|
||||
# 检查代码安全性
|
||||
is_safe, safety_error = self._check_code_safety(code)
|
||||
if not is_safe:
|
||||
return {
|
||||
"success": False,
|
||||
"output": "",
|
||||
"error": f"代码安全检查失败: {safety_error}",
|
||||
"variables": {},
|
||||
}
|
||||
|
||||
# 记录执行前的变量
|
||||
vars_before = set(self.shell.user_ns.keys())
|
||||
|
||||
try:
|
||||
# 使用IPython的capture_output来捕获所有输出
|
||||
with capture_output() as captured:
|
||||
result = self.shell.run_cell(code)
|
||||
|
||||
# 检查执行结果
|
||||
if result.error_before_exec:
|
||||
error_msg = str(result.error_before_exec)
|
||||
return {
|
||||
"success": False,
|
||||
"output": captured.stdout,
|
||||
"error": f"执行前错误: {error_msg}",
|
||||
"variables": {},
|
||||
}
|
||||
|
||||
if result.error_in_exec:
|
||||
error_msg = str(result.error_in_exec)
|
||||
return {
|
||||
"success": False,
|
||||
"output": captured.stdout,
|
||||
"error": f"执行错误: {error_msg}",
|
||||
"variables": {},
|
||||
}
|
||||
|
||||
# 获取输出
|
||||
output = captured.stdout
|
||||
|
||||
# 如果有返回值,添加到输出
|
||||
if result.result is not None:
|
||||
formatted_result = self._format_table_output(result.result)
|
||||
output += f"\n{formatted_result}"
|
||||
# 记录新产生的重要变量(简化版本)
|
||||
vars_after = set(self.shell.user_ns.keys())
|
||||
new_vars = vars_after - vars_before
|
||||
|
||||
# 只记录新创建的DataFrame等重要数据结构
|
||||
important_new_vars = {}
|
||||
for var_name in new_vars:
|
||||
if not var_name.startswith("_"):
|
||||
try:
|
||||
var_value = self.shell.user_ns[var_name]
|
||||
if hasattr(var_value, "shape"): # pandas DataFrame, numpy array
|
||||
important_new_vars[var_name] = (
|
||||
f"{type(var_value).__name__} with shape {var_value.shape}"
|
||||
)
|
||||
elif var_name in ["session_output_dir"]: # 重要的配置变量
|
||||
important_new_vars[var_name] = str(var_value)
|
||||
except:
|
||||
pass
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"output": output,
|
||||
"error": "",
|
||||
"variables": important_new_vars,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"output": captured.stdout if "captured" in locals() else "",
|
||||
"error": f"执行异常: {str(e)}\n{traceback.format_exc()}",
|
||||
"variables": {},
|
||||
}
|
||||
|
||||
def reset_environment(self):
|
||||
"""重置执行环境"""
|
||||
self.shell.reset()
|
||||
self._setup_common_imports()
|
||||
self._setup_chinese_font()
|
||||
plt.close("all")
|
||||
self.image_counter = 0
|
||||
|
||||
def set_variable(self, name: str, value: Any):
|
||||
"""设置执行环境中的变量"""
|
||||
self.shell.user_ns[name] = value
|
||||
|
||||
def get_environment_info(self) -> str:
|
||||
"""获取当前执行环境的变量信息,用于系统提示词"""
|
||||
info_parts = []
|
||||
|
||||
# 获取重要的数据变量
|
||||
important_vars = {}
|
||||
for var_name, var_value in self.shell.user_ns.items():
|
||||
if not var_name.startswith("_") and var_name not in [
|
||||
"In",
|
||||
"Out",
|
||||
"get_ipython",
|
||||
"exit",
|
||||
"quit",
|
||||
]:
|
||||
try:
|
||||
if hasattr(var_value, "shape"): # pandas DataFrame, numpy array
|
||||
important_vars[var_name] = (
|
||||
f"{type(var_value).__name__} with shape {var_value.shape}"
|
||||
)
|
||||
elif var_name in ["session_output_dir"]: # 重要的路径变量
|
||||
important_vars[var_name] = str(var_value)
|
||||
elif (
|
||||
isinstance(var_value, (int, float, str, bool))
|
||||
and len(str(var_value)) < 100
|
||||
):
|
||||
important_vars[var_name] = (
|
||||
f"{type(var_value).__name__}: {var_value}"
|
||||
)
|
||||
elif hasattr(var_value, "__module__") and var_value.__module__ in [
|
||||
"pandas",
|
||||
"numpy",
|
||||
"matplotlib.pyplot",
|
||||
]:
|
||||
important_vars[var_name] = f"导入的模块: {var_value.__module__}"
|
||||
except:
|
||||
continue
|
||||
|
||||
if important_vars:
|
||||
info_parts.append("当前环境变量:")
|
||||
for var_name, var_info in important_vars.items():
|
||||
info_parts.append(f"- {var_name}: {var_info}")
|
||||
else:
|
||||
info_parts.append("当前环境已预装pandas, numpy, matplotlib等库")
|
||||
|
||||
# 添加输出目录信息
|
||||
if "session_output_dir" in self.shell.user_ns:
|
||||
info_parts.append(
|
||||
f"图片保存目录: session_output_dir = '{self.shell.user_ns['session_output_dir']}'"
|
||||
)
|
||||
|
||||
return "\n".join(info_parts)
|
||||
15
utils/create_session_dir.py
Normal file
15
utils/create_session_dir.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def create_session_output_dir(base_output_dir, user_input: str) -> str:
|
||||
"""为本次分析创建独立的输出目录"""
|
||||
|
||||
# 使用当前时间创建唯一的会话目录名(格式:YYYYMMDD_HHMMSS)
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
session_id = timestamp
|
||||
dir_name = f"session_{session_id}"
|
||||
session_dir = os.path.join(base_output_dir, dir_name)
|
||||
os.makedirs(session_dir, exist_ok=True)
|
||||
|
||||
return session_dir
|
||||
90
utils/data_loader.py
Normal file
90
utils/data_loader.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import pandas as pd
|
||||
import io
|
||||
|
||||
def load_and_profile_data(file_paths: list) -> str:
|
||||
"""
|
||||
加载数据并生成数据画像
|
||||
|
||||
Args:
|
||||
file_paths: 文件路径列表
|
||||
|
||||
Returns:
|
||||
包含数据画像的Markdown字符串
|
||||
"""
|
||||
profile_summary = "# 数据画像报告 (Data Profile)\n\n"
|
||||
|
||||
if not file_paths:
|
||||
return profile_summary + "未提供数据文件。"
|
||||
|
||||
for file_path in file_paths:
|
||||
file_name = os.path.basename(file_path)
|
||||
profile_summary += f"## 文件: {file_name}\n\n"
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
profile_summary += f"⚠️ 文件不存在: {file_path}\n\n"
|
||||
continue
|
||||
|
||||
try:
|
||||
# 根据扩展名选择加载方式
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
if ext == '.csv':
|
||||
# 尝试多种编码
|
||||
try:
|
||||
df = pd.read_csv(file_path, encoding='utf-8')
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
df = pd.read_csv(file_path, encoding='gbk')
|
||||
except Exception:
|
||||
df = pd.read_csv(file_path, encoding='latin1')
|
||||
elif ext in ['.xlsx', '.xls']:
|
||||
df = pd.read_excel(file_path)
|
||||
else:
|
||||
profile_summary += f"⚠️ 不支持的文件格式: {ext}\n\n"
|
||||
continue
|
||||
|
||||
# 基础信息
|
||||
rows, cols = df.shape
|
||||
profile_summary += f"- **维度**: {rows} 行 x {cols} 列\n"
|
||||
profile_summary += f"- **列名**: `{', '.join(df.columns)}`\n\n"
|
||||
|
||||
profile_summary += "### 列详细分布:\n"
|
||||
|
||||
# 遍历分析每列
|
||||
for col in df.columns:
|
||||
dtype = df[col].dtype
|
||||
null_count = df[col].isnull().sum()
|
||||
null_ratio = (null_count / rows) * 100
|
||||
|
||||
profile_summary += f"#### {col} ({dtype})\n"
|
||||
if null_count > 0:
|
||||
profile_summary += f"- ⚠️ 空值: {null_count} ({null_ratio:.1f}%)\n"
|
||||
|
||||
# 数值列分析
|
||||
if pd.api.types.is_numeric_dtype(dtype):
|
||||
desc = df[col].describe()
|
||||
profile_summary += f"- 统计: Min={desc['min']:.2f}, Max={desc['max']:.2f}, Mean={desc['mean']:.2f}\n"
|
||||
|
||||
# 文本/分类列分析
|
||||
elif pd.api.types.is_object_dtype(dtype) or pd.api.types.is_categorical_dtype(dtype):
|
||||
unique_count = df[col].nunique()
|
||||
profile_summary += f"- 唯一值数量: {unique_count}\n"
|
||||
|
||||
# 如果唯一值较少(<50)或者看起来是分类数据,显示Top分布
|
||||
# 这对识别“高频问题”至关重要
|
||||
if unique_count > 0:
|
||||
top_n = df[col].value_counts().head(5)
|
||||
top_items_str = ", ".join([f"{k}({v})" for k, v in top_n.items()])
|
||||
profile_summary += f"- **TOP 5 高频值**: {top_items_str}\n"
|
||||
|
||||
# 时间列分析
|
||||
elif pd.api.types.is_datetime64_any_dtype(dtype):
|
||||
profile_summary += f"- 范围: {df[col].min()} 至 {df[col].max()}\n"
|
||||
|
||||
profile_summary += "\n"
|
||||
|
||||
except Exception as e:
|
||||
profile_summary += f"❌ 读取或分析文件失败: {str(e)}\n\n"
|
||||
|
||||
return profile_summary
|
||||
38
utils/extract_code.py
Normal file
38
utils/extract_code.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
import yaml
|
||||
|
||||
|
||||
def extract_code_from_response(response: str) -> Optional[str]:
|
||||
"""从LLM响应中提取代码"""
|
||||
try:
|
||||
# 尝试解析YAML
|
||||
if '```yaml' in response:
|
||||
start = response.find('```yaml') + 7
|
||||
end = response.find('```', start)
|
||||
yaml_content = response[start:end].strip()
|
||||
elif '```' in response:
|
||||
start = response.find('```') + 3
|
||||
end = response.find('```', start)
|
||||
yaml_content = response[start:end].strip()
|
||||
else:
|
||||
yaml_content = response.strip()
|
||||
|
||||
yaml_data = yaml.safe_load(yaml_content)
|
||||
if 'code' in yaml_data:
|
||||
return yaml_data['code']
|
||||
except:
|
||||
pass
|
||||
|
||||
# 如果YAML解析失败,尝试提取```python代码块
|
||||
if '```python' in response:
|
||||
start = response.find('```python') + 9
|
||||
end = response.find('```', start)
|
||||
if end != -1:
|
||||
return response[start:end].strip()
|
||||
elif '```' in response:
|
||||
start = response.find('```') + 3
|
||||
end = response.find('```', start)
|
||||
if end != -1:
|
||||
return response[start:end].strip()
|
||||
|
||||
return None
|
||||
230
utils/fallback_openai_client.py
Normal file
230
utils/fallback_openai_client.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
from typing import Optional, Any, Mapping, Dict
|
||||
from openai import AsyncOpenAI, APIStatusError, APIConnectionError, APITimeoutError, APIError
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
class AsyncFallbackOpenAIClient:
|
||||
"""
|
||||
一个支持备用 API 自动切换的异步 OpenAI 客户端。
|
||||
当主 API 调用因特定错误(如内容过滤)失败时,会自动尝试使用备用 API。
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
primary_api_key: str,
|
||||
primary_base_url: str,
|
||||
primary_model_name: str,
|
||||
fallback_api_key: Optional[str] = None,
|
||||
fallback_base_url: Optional[str] = None,
|
||||
fallback_model_name: Optional[str] = None,
|
||||
primary_client_args: Optional[Dict[str, Any]] = None,
|
||||
fallback_client_args: Optional[Dict[str, Any]] = None,
|
||||
content_filter_error_code: str = "1301", # 特定于 Zhipu 的内容过滤错误代码
|
||||
content_filter_error_field: str = "contentFilter", # 特定于 Zhipu 的内容过滤错误字段
|
||||
max_retries_primary: int = 1, # 主API重试次数
|
||||
max_retries_fallback: int = 1, # 备用API重试次数
|
||||
retry_delay_seconds: float = 1.0 # 重试延迟时间
|
||||
):
|
||||
"""
|
||||
初始化 AsyncFallbackOpenAIClient。
|
||||
|
||||
Args:
|
||||
primary_api_key: 主 API 的密钥。
|
||||
primary_base_url: 主 API 的基础 URL。
|
||||
primary_model_name: 主 API 使用的模型名称。
|
||||
fallback_api_key: 备用 API 的密钥 (可选)。
|
||||
fallback_base_url: 备用 API 的基础 URL (可选)。
|
||||
fallback_model_name: 备用 API 使用的模型名称 (可选)。
|
||||
primary_client_args: 传递给主 AsyncOpenAI 客户端的其他参数。
|
||||
fallback_client_args: 传递给备用 AsyncOpenAI 客户端的其他参数。
|
||||
content_filter_error_code: 触发回退的内容过滤错误的特定错误代码。
|
||||
content_filter_error_field: 触发回退的内容过滤错误中存在的字段名。
|
||||
max_retries_primary: 主 API 失败时的最大重试次数。
|
||||
max_retries_fallback: 备用 API 失败时的最大重试次数。
|
||||
retry_delay_seconds: 重试前的延迟时间(秒)。
|
||||
"""
|
||||
if not primary_api_key or not primary_base_url:
|
||||
raise ValueError("主 API 密钥和基础 URL 不能为空。")
|
||||
|
||||
_primary_args = primary_client_args or {}
|
||||
self.primary_client = AsyncOpenAI(api_key=primary_api_key, base_url=primary_base_url, **_primary_args)
|
||||
self.primary_model_name = primary_model_name
|
||||
|
||||
self.fallback_client: Optional[AsyncOpenAI] = None
|
||||
self.fallback_model_name: Optional[str] = None
|
||||
if fallback_api_key and fallback_base_url and fallback_model_name:
|
||||
_fallback_args = fallback_client_args or {}
|
||||
self.fallback_client = AsyncOpenAI(api_key=fallback_api_key, base_url=fallback_base_url, **_fallback_args)
|
||||
self.fallback_model_name = fallback_model_name
|
||||
else:
|
||||
print("⚠️ 警告: 未完全配置备用 API 客户端。如果主 API 失败,将无法进行回退。")
|
||||
|
||||
self.content_filter_error_code = content_filter_error_code
|
||||
self.content_filter_error_field = content_filter_error_field
|
||||
self.max_retries_primary = max_retries_primary
|
||||
self.max_retries_fallback = max_retries_fallback
|
||||
self.retry_delay_seconds = retry_delay_seconds
|
||||
self._closed = False
|
||||
|
||||
async def _attempt_api_call(
|
||||
self,
|
||||
client: AsyncOpenAI,
|
||||
model_name: str,
|
||||
messages: list[Mapping[str, Any]],
|
||||
max_retries: int,
|
||||
api_name: str,
|
||||
**kwargs: Any
|
||||
) -> ChatCompletion:
|
||||
"""
|
||||
尝试调用指定的 OpenAI API 客户端,并进行重试。
|
||||
"""
|
||||
last_exception = None
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
# print(f"尝试使用 {api_name} API ({client.base_url}) 模型: {kwargs.get('model', model_name)}, 第 {attempt + 1} 次尝试")
|
||||
completion = await client.chat.completions.create(
|
||||
model=kwargs.pop('model', model_name),
|
||||
messages=messages,
|
||||
**kwargs
|
||||
)
|
||||
return completion
|
||||
except (APIConnectionError, APITimeoutError) as e: # 通常可以重试的网络错误
|
||||
last_exception = e
|
||||
print(f"⚠️ {api_name} API 调用时发生可重试错误 ({type(e).__name__}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
|
||||
if attempt < max_retries:
|
||||
await asyncio.sleep(self.retry_delay_seconds * (attempt + 1)) # 增加延迟
|
||||
else:
|
||||
print(f"❌ {api_name} API 在达到最大重试次数后仍然失败。")
|
||||
except APIStatusError as e: # API 返回的特定状态码错误
|
||||
is_content_filter_error = False
|
||||
if e.status_code == 400:
|
||||
try:
|
||||
error_json = e.response.json()
|
||||
error_details = error_json.get("error", {})
|
||||
if (error_details.get("code") == self.content_filter_error_code and
|
||||
self.content_filter_error_field in error_json):
|
||||
is_content_filter_error = True
|
||||
except Exception:
|
||||
pass # 解析错误响应失败,不认为是内容过滤错误
|
||||
|
||||
if is_content_filter_error and api_name == "主": # 如果是主 API 的内容过滤错误,则直接抛出以便回退
|
||||
raise e
|
||||
|
||||
last_exception = e
|
||||
print(f"⚠️ {api_name} API 调用时发生 APIStatusError ({e.status_code}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
|
||||
if attempt < max_retries:
|
||||
await asyncio.sleep(self.retry_delay_seconds * (attempt + 1))
|
||||
else:
|
||||
print(f"❌ {api_name} API 在达到最大重试次数后仍然失败 (APIStatusError)。")
|
||||
except APIError as e: # 其他不可轻易重试的 OpenAI 错误
|
||||
last_exception = e
|
||||
print(f"❌ {api_name} API 调用时发生不可重试错误 ({type(e).__name__}): {e}")
|
||||
break # 不再重试此类错误
|
||||
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
raise RuntimeError(f"{api_name} API 调用意外失败。") # 理论上不应到达这里
|
||||
|
||||
async def chat_completions_create(
|
||||
self,
|
||||
messages: list[Mapping[str, Any]],
|
||||
**kwargs: Any # 用于传递其他 OpenAI 参数,如 max_tokens, temperature 等。
|
||||
) -> ChatCompletion:
|
||||
"""
|
||||
使用主 API 创建聊天补全,如果发生特定内容过滤错误或主 API 调用失败,则回退到备用 API。
|
||||
支持对主 API 和备用 API 的可重试错误进行重试。
|
||||
|
||||
Args:
|
||||
messages: OpenAI API 的消息列表。
|
||||
**kwargs: 传递给 OpenAI API 调用的其他参数。
|
||||
|
||||
Returns:
|
||||
ChatCompletion 对象。
|
||||
|
||||
Raises:
|
||||
APIError: 如果主 API 和备用 API (如果尝试) 都返回 API 错误。
|
||||
RuntimeError: 如果客户端已关闭。
|
||||
"""
|
||||
if self._closed:
|
||||
raise RuntimeError("客户端已关闭。")
|
||||
|
||||
try:
|
||||
completion = await self._attempt_api_call(
|
||||
client=self.primary_client,
|
||||
model_name=self.primary_model_name,
|
||||
messages=messages,
|
||||
max_retries=self.max_retries_primary,
|
||||
api_name="主",
|
||||
**kwargs.copy()
|
||||
)
|
||||
return completion
|
||||
except APIStatusError as e_primary:
|
||||
is_content_filter_error = False
|
||||
if e_primary.status_code == 400:
|
||||
try:
|
||||
error_json = e_primary.response.json()
|
||||
error_details = error_json.get("error", {})
|
||||
if (error_details.get("code") == self.content_filter_error_code and
|
||||
self.content_filter_error_field in error_json):
|
||||
is_content_filter_error = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if is_content_filter_error and self.fallback_client and self.fallback_model_name:
|
||||
print(f"ℹ️ 主 API 内容过滤错误 ({e_primary.status_code})。尝试切换到备用 API ({self.fallback_client.base_url})...")
|
||||
try:
|
||||
fallback_completion = await self._attempt_api_call(
|
||||
client=self.fallback_client,
|
||||
model_name=self.fallback_model_name,
|
||||
messages=messages,
|
||||
max_retries=self.max_retries_fallback,
|
||||
api_name="备用",
|
||||
**kwargs.copy()
|
||||
)
|
||||
print(f"✅ 备用 API 调用成功。")
|
||||
return fallback_completion
|
||||
except APIError as e_fallback:
|
||||
print(f"❌ 备用 API 调用最终失败: {type(e_fallback).__name__} - {e_fallback}")
|
||||
raise e_fallback
|
||||
else:
|
||||
if not (self.fallback_client and self.fallback_model_name and is_content_filter_error):
|
||||
# 如果不是内容过滤错误,或者没有可用的备用API,则记录主API的原始错误
|
||||
print(f"ℹ️ 主 API 错误 ({type(e_primary).__name__}: {e_primary}), 且不满足备用条件或备用API未配置。")
|
||||
raise e_primary
|
||||
except APIError as e_primary_other:
|
||||
print(f"❌ 主 API 调用最终失败 (非内容过滤,错误类型: {type(e_primary_other).__name__}): {e_primary_other}")
|
||||
if self.fallback_client and self.fallback_model_name:
|
||||
print(f"ℹ️ 主 API 失败,尝试切换到备用 API ({self.fallback_client.base_url})...")
|
||||
try:
|
||||
fallback_completion = await self._attempt_api_call(
|
||||
client=self.fallback_client,
|
||||
model_name=self.fallback_model_name,
|
||||
messages=messages,
|
||||
max_retries=self.max_retries_fallback,
|
||||
api_name="备用",
|
||||
**kwargs.copy()
|
||||
)
|
||||
print(f"✅ 备用 API 调用成功。")
|
||||
return fallback_completion
|
||||
except APIError as e_fallback_after_primary_fail:
|
||||
print(f"❌ 备用 API 在主 API 失败后也调用失败: {type(e_fallback_after_primary_fail).__name__} - {e_fallback_after_primary_fail}")
|
||||
raise e_fallback_after_primary_fail
|
||||
else:
|
||||
raise e_primary_other
|
||||
|
||||
async def close(self):
|
||||
"""异步关闭主客户端和备用客户端 (如果存在)。"""
|
||||
if not self._closed:
|
||||
await self.primary_client.close()
|
||||
if self.fallback_client:
|
||||
await self.fallback_client.close()
|
||||
self._closed = True
|
||||
# print("AsyncFallbackOpenAIClient 已关闭。")
|
||||
|
||||
async def __aenter__(self):
|
||||
if self._closed:
|
||||
raise RuntimeError("AsyncFallbackOpenAIClient 不能在关闭后重新进入。请创建一个新实例。")
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
25
utils/format_execution_result.py
Normal file
25
utils/format_execution_result.py
Normal file
@@ -0,0 +1,25 @@
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def format_execution_result(result: Dict[str, Any]) -> str:
|
||||
"""格式化执行结果为用户可读的反馈"""
|
||||
feedback = []
|
||||
|
||||
if result['success']:
|
||||
feedback.append("✅ 代码执行成功")
|
||||
|
||||
if result['output']:
|
||||
feedback.append(f"📊 输出结果:\n{result['output']}")
|
||||
|
||||
if result.get('variables'):
|
||||
feedback.append("📋 新生成的变量:")
|
||||
for var_name, var_info in result['variables'].items():
|
||||
feedback.append(f" - {var_name}: {var_info}")
|
||||
else:
|
||||
feedback.append("❌ 代码执行失败")
|
||||
feedback.append(f"错误信息: {result['error']}")
|
||||
if result['output']:
|
||||
feedback.append(f"部分输出: {result['output']}")
|
||||
|
||||
return "\n".join(feedback)
|
||||
86
utils/llm_helper.py
Normal file
86
utils/llm_helper.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
LLM调用辅助模块
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import yaml
|
||||
from config.llm_config import LLMConfig
|
||||
from utils.fallback_openai_client import AsyncFallbackOpenAIClient
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM调用辅助类,支持同步和异步调用"""
|
||||
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
self.config = config
|
||||
self.client = AsyncFallbackOpenAIClient(
|
||||
primary_api_key=config.api_key,
|
||||
primary_base_url=config.base_url,
|
||||
primary_model_name=config.model
|
||||
)
|
||||
|
||||
async def async_call(self, prompt: str, system_prompt: str = None, max_tokens: int = None, temperature: float = None) -> str:
|
||||
"""异步调用LLM"""
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
kwargs = {}
|
||||
if max_tokens is not None:
|
||||
kwargs['max_tokens'] = max_tokens
|
||||
else:
|
||||
kwargs['max_tokens'] = self.config.max_tokens
|
||||
|
||||
if temperature is not None:
|
||||
kwargs['temperature'] = temperature
|
||||
else:
|
||||
kwargs['temperature'] = self.config.temperature
|
||||
|
||||
try:
|
||||
response = await self.client.chat_completions_create(
|
||||
messages=messages,
|
||||
**kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
print(f"LLM调用失败: {e}")
|
||||
return ""
|
||||
|
||||
def call(self, prompt: str, system_prompt: str = None, max_tokens: int = None, temperature: float = None) -> str:
|
||||
"""同步调用LLM"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
|
||||
return loop.run_until_complete(self.async_call(prompt, system_prompt, max_tokens, temperature))
|
||||
|
||||
def parse_yaml_response(self, response: str) -> dict:
|
||||
"""解析YAML格式的响应"""
|
||||
try:
|
||||
# 提取```yaml和```之间的内容
|
||||
if '```yaml' in response:
|
||||
start = response.find('```yaml') + 7
|
||||
end = response.find('```', start)
|
||||
yaml_content = response[start:end].strip()
|
||||
elif '```' in response:
|
||||
start = response.find('```') + 3
|
||||
end = response.find('```', start)
|
||||
yaml_content = response[start:end].strip()
|
||||
else:
|
||||
yaml_content = response.strip()
|
||||
|
||||
return yaml.safe_load(yaml_content)
|
||||
except Exception as e:
|
||||
print(f"YAML解析失败: {e}")
|
||||
print(f"原始响应: {response}")
|
||||
return {}
|
||||
|
||||
async def close(self):
|
||||
"""关闭客户端"""
|
||||
await self.client.close()
|
||||
Reference in New Issue
Block a user