"""
Common utilities — JSON extraction, LLM client singleton, retry mechanism.
"""

import json
import re
import time
from typing import Any

import openai
# ── LLM client singleton ──────────────────────────────────

# Lazily-created shared client; populated on the first get_llm_client() call
# so individual components do not each build their own connection.
_llm_client: openai.OpenAI | None = None

# Model name recorded alongside the client when it is first created.
_llm_model: str = ""
|
def get_llm_client(config: dict) -> tuple[openai.OpenAI, str]:
    """Return the process-wide LLM client (singleton) and its model name.

    The client is built lazily on the first call and cached in module
    globals, so every component shares one connection.

    Args:
        config: dict with "api_key", "base_url" and "model" entries.

    Raises:
        RuntimeError: if no API key is present in ``config``.
    """
    global _llm_client, _llm_model

    # Fast path: already initialised — later config changes are ignored.
    if _llm_client is not None:
        return _llm_client, _llm_model

    api_key = config.get("api_key", "")
    if not api_key:
        raise RuntimeError(
            "LLM_API_KEY 未配置!请设置环境变量或在 .env 文件中添加:\n"
            " LLM_API_KEY=your-key\n"
            " LLM_BASE_URL=https://api.openai.com/v1\n"
            " LLM_MODEL=gpt-4o-mini"
        )

    _llm_client = openai.OpenAI(api_key=api_key, base_url=config["base_url"])
    _llm_model = config["model"]
    return _llm_client, _llm_model
|
|
|
|
|
|
|
|
|
|
|
# ── Retry wrapper for LLM calls ────────────────────────


class LLMCallError(Exception):
    """Raised when an LLM call has exhausted all retries."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def llm_chat(client: openai.OpenAI, model: str, messages: list[dict],
             max_retries: int = 3, **kwargs) -> str:
    """Call the LLM with exponential-backoff retries.

    Handles 429 rate limits, timeouts/5xx status errors and network
    failures by sleeping and retrying; other exceptions are retried too
    but re-raised as-is on the final attempt.

    Args:
        client: shared OpenAI client (see get_llm_client).
        model: model name forwarded to the API.
        messages: chat messages in OpenAI format.
        max_retries: total number of attempts before giving up.
        **kwargs: forwarded verbatim to chat.completions.create.

    Returns:
        The stripped text content of the first choice.

    Raises:
        LLMCallError: when a retryable API error persists through all attempts.
        Exception: a non-API error is re-raised unchanged on the last attempt.
    """
    last_err = None
    for attempt in range(max_retries):
        # Bug fix: the original slept even after the final attempt, wasting
        # up to 30s before raising with no retry to follow.
        is_last = attempt == max_retries - 1
        try:
            response = client.chat.completions.create(
                model=model, messages=messages, **kwargs
            )
            return response.choices[0].message.content.strip()
        except openai.RateLimitError as e:
            last_err = e
            if not is_last:
                # Honour the server's Retry-After when present.
                wait = _get_retry_delay(e, attempt)
                print(f" ⏳ 限频,等待 {wait:.1f}s 后重试 ({attempt+1}/{max_retries})...")
                time.sleep(wait)
        except (openai.APITimeoutError, openai.APIConnectionError, openai.APIStatusError) as e:
            last_err = e
            if not is_last:
                wait = min(2 ** attempt * 2, 30)
                print(f" ⚠️ API 错误: {type(e).__name__},等待 {wait:.1f}s ({attempt+1}/{max_retries})...")
                time.sleep(wait)
        except Exception as e:
            last_err = e
            if is_last:
                raise
            time.sleep(min(2 ** attempt * 2, 30))

    raise LLMCallError(f"LLM 调用失败({max_retries} 次重试): {last_err}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_retry_delay(error, attempt: int) -> float:
|
|
|
|
|
"""从错误响应中提取重试等待时间"""
|
|
|
|
|
try:
|
|
|
|
|
if hasattr(error, 'response') and error.response is not None:
|
|
|
|
|
retry_after = error.response.headers.get('Retry-After')
|
|
|
|
|
if retry_after:
|
|
|
|
|
return float(retry_after)
|
|
|
|
|
except Exception:
|
|
|
|
|
pass
|
|
|
|
|
# 指数退避: 2s, 4s, 8s, 最大 30s
|
|
|
|
|
return min(2 ** (attempt + 1), 30)
|
|
|
|
|
|
|
|
|
|
|
# ── JSON extraction ────────────────────────────────────


def extract_json_object(text: str) -> dict:
    """Extract a JSON object from raw LLM output.

    Tries, in order: the whole (cleaned) text, fenced ```json / ``` code
    blocks, then the first brace-balanced substring (one nesting level).
    Returns {} when nothing parses to a dict.
    """
    text = _clean_json_text(text)

    # 1) The whole payload may already be valid JSON.
    try:
        result = json.loads(text)
        # Bug fix: only accept an actual object — previously a top-level
        # array/scalar was returned despite the declared dict return type
        # (extract_json_array already guards with isinstance).
        if isinstance(result, dict):
            return result
    except json.JSONDecodeError:
        pass

    # 2) Look inside markdown code fences.
    for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']:
        match = re.search(pattern, text, re.DOTALL)
        if match:
            try:
                result = json.loads(_clean_json_text(match.group(1)))
                if isinstance(result, dict):
                    return result
            except json.JSONDecodeError:
                continue

    # 3) Last resort: first {...} span; the regex itself guarantees an object.
    match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
    if match:
        try:
            return json.loads(_clean_json_text(match.group()))
        except json.JSONDecodeError:
            pass

    return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_json_array(text: str) -> list[dict]:
    """Extract a JSON array from raw LLM output.

    Tolerates trailing commas and // or /* */ comments (via
    _clean_json_text).  Fallback order: whole text -> fenced code
    blocks -> widest [...] span.  Returns [] when nothing parses
    to a list.
    """
    cleaned = _clean_json_text(text)

    # Whole payload first.
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, list):
        return parsed

    # Then markdown code fences (```json preferred over a bare ```).
    for fence in (r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```'):
        hit = re.search(fence, cleaned, re.DOTALL)
        if not hit:
            continue
        try:
            parsed = json.loads(_clean_json_text(hit.group(1)))
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, list):
            return parsed

    # Finally the widest bracketed span.
    hit = re.search(r'\[.*\]', cleaned, re.DOTALL)
    if hit:
        try:
            return json.loads(_clean_json_text(hit.group()))
        except json.JSONDecodeError:
            pass

    return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _clean_json_text(s: str) -> str:
|
|
|
|
|
"""清理 LLM 常见的非标准 JSON"""
|
|
|
|
|
s = re.sub(r'//.*?\n', '\n', s)
|
|
|
|
|
s = re.sub(r'/\*.*?\*/', '', s, flags=re.DOTALL)
|
|
|
|
|
s = re.sub(r',\s*([}\]])', r'\1', s)
|
|
|
|
|
return s.strip()
|