""" 公共工具 —— JSON 提取、LLM 客户端单例、重试机制 """ import json import re import time from typing import Any import openai # ── LLM 客户端单例 ────────────────────────────────── _llm_client: openai.OpenAI | None = None _llm_model: str = "" def get_llm_client(config: dict) -> tuple[openai.OpenAI, str]: """获取 LLM 客户端(单例),避免每个组件各建一个""" global _llm_client, _llm_model if _llm_client is None: api_key = config.get("api_key", "") if not api_key: raise RuntimeError( "LLM_API_KEY 未配置!请设置环境变量或在 .env 文件中添加:\n" " LLM_API_KEY=your-key\n" " LLM_BASE_URL=https://api.openai.com/v1\n" " LLM_MODEL=gpt-4o-mini" ) _llm_client = openai.OpenAI( api_key=api_key, base_url=config["base_url"], ) _llm_model = config["model"] return _llm_client, _llm_model # ── LLM 调用重试包装 ──────────────────────────────── class LLMCallError(Exception): """LLM 调用最终失败""" pass def llm_chat(client: openai.OpenAI, model: str, messages: list[dict], max_retries: int = 3, **kwargs) -> str: """ 带指数退避重试的 LLM 调用。 处理 429 限频、5xx 超时、网络错误。 """ last_err = None for attempt in range(max_retries): try: response = client.chat.completions.create( model=model, messages=messages, **kwargs ) return response.choices[0].message.content.strip() except openai.RateLimitError as e: last_err = e # 读取 Retry-After 或使用默认退避 wait = _get_retry_delay(e, attempt) print(f" ⏳ 限频,等待 {wait:.1f}s 后重试 ({attempt+1}/{max_retries})...") time.sleep(wait) except (openai.APITimeoutError, openai.APIConnectionError, openai.APIStatusError) as e: last_err = e wait = min(2 ** attempt * 2, 30) print(f" ⚠️ API 错误: {type(e).__name__},等待 {wait:.1f}s ({attempt+1}/{max_retries})...") time.sleep(wait) except Exception as e: last_err = e if attempt < max_retries - 1: wait = min(2 ** attempt * 2, 30) time.sleep(wait) continue raise raise LLMCallError(f"LLM 调用失败({max_retries} 次重试): {last_err}") def _get_retry_delay(error, attempt: int) -> float: """从错误响应中提取重试等待时间""" try: if hasattr(error, 'response') and error.response is not None: retry_after = error.response.headers.get('Retry-After') if retry_after: return float(retry_after) except Exception: pass # 指数退避: 2s, 4s, 8s, 最大 30s return min(2 ** (attempt + 1), 30) # ── JSON 提取 ──────────────────────────────────────── def extract_json_object(text: str) -> dict: """从 LLM 输出提取 JSON 对象""" text = _clean_json_text(text) try: return json.loads(text) except json.JSONDecodeError: pass for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']: match = re.search(pattern, text, re.DOTALL) if match: try: return json.loads(_clean_json_text(match.group(1))) except json.JSONDecodeError: continue match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL) if match: try: return json.loads(_clean_json_text(match.group())) except json.JSONDecodeError: pass return {} def extract_json_array(text: str) -> list[dict]: """从 LLM 输出提取 JSON 数组(处理尾逗号、注释等)""" text = _clean_json_text(text) try: result = json.loads(text) if isinstance(result, list): return result except json.JSONDecodeError: pass for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']: match = re.search(pattern, text, re.DOTALL) if match: try: result = json.loads(_clean_json_text(match.group(1))) if isinstance(result, list): return result except json.JSONDecodeError: continue match = re.search(r'\[.*\]', text, re.DOTALL) if match: try: return json.loads(_clean_json_text(match.group())) except json.JSONDecodeError: pass return [] def _clean_json_text(s: str) -> str: """清理 LLM 常见的非标准 JSON""" s = re.sub(r'//.*?\n', '\n', s) s = re.sub(r'/\*.*?\*/', '', s, flags=re.DOTALL) s = re.sub(r',\s*([}\]])', r'\1', s) return s.strip()