Files
iov_data_analysis_agent/utils/fallback_openai_client.py

231 lines
11 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
import asyncio
from typing import Optional, Any, Mapping, Dict
from openai import AsyncOpenAI, APIStatusError, APIConnectionError, APITimeoutError, APIError
from openai.types.chat import ChatCompletion
class AsyncFallbackOpenAIClient:
"""
一个支持备用 API 自动切换的异步 OpenAI 客户端
当主 API 调用因特定错误如内容过滤失败时会自动尝试使用备用 API
"""
def __init__(
self,
primary_api_key: str,
primary_base_url: str,
primary_model_name: str,
fallback_api_key: Optional[str] = None,
fallback_base_url: Optional[str] = None,
fallback_model_name: Optional[str] = None,
primary_client_args: Optional[Dict[str, Any]] = None,
fallback_client_args: Optional[Dict[str, Any]] = None,
content_filter_error_code: str = "1301", # 特定于 Zhipu 的内容过滤错误代码
content_filter_error_field: str = "contentFilter", # 特定于 Zhipu 的内容过滤错误字段
max_retries_primary: int = 1, # 主API重试次数
max_retries_fallback: int = 1, # 备用API重试次数
retry_delay_seconds: float = 1.0 # 重试延迟时间
):
"""
初始化 AsyncFallbackOpenAIClient
Args:
primary_api_key: API 的密钥
primary_base_url: API 的基础 URL
primary_model_name: API 使用的模型名称
fallback_api_key: 备用 API 的密钥 (可选)
fallback_base_url: 备用 API 的基础 URL (可选)
fallback_model_name: 备用 API 使用的模型名称 (可选)
primary_client_args: 传递给主 AsyncOpenAI 客户端的其他参数
fallback_client_args: 传递给备用 AsyncOpenAI 客户端的其他参数
content_filter_error_code: 触发回退的内容过滤错误的特定错误代码
content_filter_error_field: 触发回退的内容过滤错误中存在的字段名
max_retries_primary: API 失败时的最大重试次数
max_retries_fallback: 备用 API 失败时的最大重试次数
retry_delay_seconds: 重试前的延迟时间
"""
if not primary_api_key or not primary_base_url:
raise ValueError("主 API 密钥和基础 URL 不能为空。")
_primary_args = primary_client_args or {}
self.primary_client = AsyncOpenAI(api_key=primary_api_key, base_url=primary_base_url, **_primary_args)
self.primary_model_name = primary_model_name
self.fallback_client: Optional[AsyncOpenAI] = None
self.fallback_model_name: Optional[str] = None
if fallback_api_key and fallback_base_url and fallback_model_name:
_fallback_args = fallback_client_args or {}
self.fallback_client = AsyncOpenAI(api_key=fallback_api_key, base_url=fallback_base_url, **_fallback_args)
self.fallback_model_name = fallback_model_name
else:
print("⚠️ 警告: 未完全配置备用 API 客户端。如果主 API 失败,将无法进行回退。")
self.content_filter_error_code = content_filter_error_code
self.content_filter_error_field = content_filter_error_field
self.max_retries_primary = max_retries_primary
self.max_retries_fallback = max_retries_fallback
self.retry_delay_seconds = retry_delay_seconds
self._closed = False
async def _attempt_api_call(
self,
client: AsyncOpenAI,
model_name: str,
messages: list[Mapping[str, Any]],
max_retries: int,
api_name: str,
**kwargs: Any
) -> ChatCompletion:
"""
尝试调用指定的 OpenAI API 客户端并进行重试
"""
last_exception = None
for attempt in range(max_retries + 1):
try:
# print(f"尝试使用 {api_name} API ({client.base_url}) 模型: {kwargs.get('model', model_name)}, 第 {attempt + 1} 次尝试")
completion = await client.chat.completions.create(
model=kwargs.pop('model', model_name),
messages=messages,
**kwargs
)
return completion
except (APIConnectionError, APITimeoutError) as e: # 通常可以重试的网络错误
last_exception = e
print(f"⚠️ {api_name} API 调用时发生可重试错误 ({type(e).__name__}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
if attempt < max_retries:
await asyncio.sleep(self.retry_delay_seconds * (attempt + 1)) # 增加延迟
else:
print(f"{api_name} API 在达到最大重试次数后仍然失败。")
except APIStatusError as e: # API 返回的特定状态码错误
is_content_filter_error = False
if e.status_code == 400:
try:
error_json = e.response.json()
error_details = error_json.get("error", {})
if (error_details.get("code") == self.content_filter_error_code and
self.content_filter_error_field in error_json):
is_content_filter_error = True
except Exception:
pass # 解析错误响应失败,不认为是内容过滤错误
if is_content_filter_error and api_name == "": # 如果是主 API 的内容过滤错误,则直接抛出以便回退
raise e
last_exception = e
print(f"⚠️ {api_name} API 调用时发生 APIStatusError ({e.status_code}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
if attempt < max_retries:
await asyncio.sleep(self.retry_delay_seconds * (attempt + 1))
else:
print(f"{api_name} API 在达到最大重试次数后仍然失败 (APIStatusError)。")
except APIError as e: # 其他不可轻易重试的 OpenAI 错误
last_exception = e
print(f"{api_name} API 调用时发生不可重试错误 ({type(e).__name__}): {e}")
break # 不再重试此类错误
if last_exception:
raise last_exception
raise RuntimeError(f"{api_name} API 调用意外失败。") # 理论上不应到达这里
async def chat_completions_create(
self,
messages: list[Mapping[str, Any]],
**kwargs: Any # 用于传递其他 OpenAI 参数,如 max_tokens, temperature 等。
) -> ChatCompletion:
"""
使用主 API 创建聊天补全如果发生特定内容过滤错误或主 API 调用失败则回退到备用 API
支持对主 API 和备用 API 的可重试错误进行重试
Args:
messages: OpenAI API 的消息列表
**kwargs: 传递给 OpenAI API 调用的其他参数
Returns:
ChatCompletion 对象
Raises:
APIError: 如果主 API 和备用 API (如果尝试) 都返回 API 错误
RuntimeError: 如果客户端已关闭
"""
if self._closed:
raise RuntimeError("客户端已关闭。")
try:
completion = await self._attempt_api_call(
client=self.primary_client,
model_name=self.primary_model_name,
messages=messages,
max_retries=self.max_retries_primary,
api_name="",
**kwargs.copy()
)
return completion
except APIStatusError as e_primary:
is_content_filter_error = False
if e_primary.status_code == 400:
try:
error_json = e_primary.response.json()
error_details = error_json.get("error", {})
if (error_details.get("code") == self.content_filter_error_code and
self.content_filter_error_field in error_json):
is_content_filter_error = True
except Exception:
pass
if is_content_filter_error and self.fallback_client and self.fallback_model_name:
print(f" 主 API 内容过滤错误 ({e_primary.status_code})。尝试切换到备用 API ({self.fallback_client.base_url})...")
try:
fallback_completion = await self._attempt_api_call(
client=self.fallback_client,
model_name=self.fallback_model_name,
messages=messages,
max_retries=self.max_retries_fallback,
api_name="备用",
**kwargs.copy()
)
print(f"✅ 备用 API 调用成功。")
return fallback_completion
except APIError as e_fallback:
print(f"❌ 备用 API 调用最终失败: {type(e_fallback).__name__} - {e_fallback}")
raise e_fallback
else:
if not (self.fallback_client and self.fallback_model_name and is_content_filter_error):
# 如果不是内容过滤错误或者没有可用的备用API则记录主API的原始错误
print(f" 主 API 错误 ({type(e_primary).__name__}: {e_primary}), 且不满足备用条件或备用API未配置。")
raise e_primary
except APIError as e_primary_other:
print(f"❌ 主 API 调用最终失败 (非内容过滤,错误类型: {type(e_primary_other).__name__}): {e_primary_other}")
if self.fallback_client and self.fallback_model_name:
print(f" 主 API 失败,尝试切换到备用 API ({self.fallback_client.base_url})...")
try:
fallback_completion = await self._attempt_api_call(
client=self.fallback_client,
model_name=self.fallback_model_name,
messages=messages,
max_retries=self.max_retries_fallback,
api_name="备用",
**kwargs.copy()
)
print(f"✅ 备用 API 调用成功。")
return fallback_completion
except APIError as e_fallback_after_primary_fail:
print(f"❌ 备用 API 在主 API 失败后也调用失败: {type(e_fallback_after_primary_fail).__name__} - {e_fallback_after_primary_fail}")
raise e_fallback_after_primary_fail
else:
raise e_primary_other
async def close(self):
"""异步关闭主客户端和备用客户端 (如果存在)。"""
if not self._closed:
await self.primary_client.close()
if self.fallback_client:
await self.fallback_client.close()
self._closed = True
# print("AsyncFallbackOpenAIClient 已关闭。")
async def __aenter__(self):
if self._closed:
raise RuntimeError("AsyncFallbackOpenAIClient 不能在关闭后重新进入。请创建一个新实例。")
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()