""" 公共工具 —— JSON 提取、LLM 客户端单例 """ import json import re from typing import Any import openai # ── LLM 客户端单例 ────────────────────────────────── _llm_client: openai.OpenAI | None = None _llm_model: str = "" def get_llm_client(config: dict) -> tuple[openai.OpenAI, str]: """获取 LLM 客户端(单例),避免每个组件各建一个""" global _llm_client, _llm_model if _llm_client is None: _llm_client = openai.OpenAI( api_key=config["api_key"], base_url=config["base_url"], ) _llm_model = config["model"] return _llm_client, _llm_model # ── JSON 提取 ──────────────────────────────────────── def extract_json_object(text: str) -> dict: """从 LLM 输出提取 JSON 对象""" text = _clean_json_text(text) try: return json.loads(text) except json.JSONDecodeError: pass for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']: match = re.search(pattern, text, re.DOTALL) if match: try: return json.loads(_clean_json_text(match.group(1))) except json.JSONDecodeError: continue match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL) if match: try: return json.loads(_clean_json_text(match.group())) except json.JSONDecodeError: pass return {} def extract_json_array(text: str) -> list[dict]: """从 LLM 输出提取 JSON 数组(处理尾逗号、注释等)""" text = _clean_json_text(text) try: result = json.loads(text) if isinstance(result, list): return result except json.JSONDecodeError: pass for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']: match = re.search(pattern, text, re.DOTALL) if match: try: result = json.loads(_clean_json_text(match.group(1))) if isinstance(result, list): return result except json.JSONDecodeError: continue match = re.search(r'\[.*\]', text, re.DOTALL) if match: try: return json.loads(_clean_json_text(match.group())) except json.JSONDecodeError: pass return [] def _clean_json_text(s: str) -> str: """清理 LLM 常见的非标准 JSON""" s = re.sub(r'//.*?\n', '\n', s) s = re.sub(r'/\*.*?\*/', '', s, flags=re.DOTALL) s = re.sub(r',\s*([}\]])', r'\1', s) return s.strip()