Files
iov_ana/core/utils.py
Jeason b7a27b12bd SQLite 持久连接 — sandbox 不再每次查询开关连接,改为 __init__ 时建连、close() 时释放
Explorer 的 system prompt 明确告知 sandbox 规则 — "每条 SQL 必须包含聚合函数或 LIMIT",减少 LLM 生成违规 SQL 浪费轮次
LLM 客户端单例 — 所有组件共享一个 openai.OpenAI 实例,不再各建各的
sanitize 顺序修复 — 小样本抑制放在 float round 之前,避免被 round 干扰
quick_detect 从 O(n²) 改为 O(n) — 按列聚合一次,加去重,不再对每行重复算整列统计
历史上下文实际生效 — get_context_for 的结果现在会注入到 Explorer 的初始 prompt 里,多轮分析时 LLM 能看到之前的发现
2026-03-20 13:20:31 +08:00

94 lines
2.7 KiB
Python

"""
公共工具 —— JSON 提取、LLM 客户端单例
"""
import json
import re
from typing import Any
import openai
# ── LLM 客户端单例 ──────────────────────────────────
_llm_client: openai.OpenAI | None = None
_llm_model: str = ""
def get_llm_client(config: dict) -> tuple[openai.OpenAI, str]:
"""获取 LLM 客户端(单例),避免每个组件各建一个"""
global _llm_client, _llm_model
if _llm_client is None:
_llm_client = openai.OpenAI(
api_key=config["api_key"],
base_url=config["base_url"],
)
_llm_model = config["model"]
return _llm_client, _llm_model
# ── JSON 提取 ────────────────────────────────────────
def extract_json_object(text: str) -> dict:
"""从 LLM 输出提取 JSON 对象"""
text = _clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError:
pass
for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']:
match = re.search(pattern, text, re.DOTALL)
if match:
try:
return json.loads(_clean_json_text(match.group(1)))
except json.JSONDecodeError:
continue
match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
if match:
try:
return json.loads(_clean_json_text(match.group()))
except json.JSONDecodeError:
pass
return {}
def extract_json_array(text: str) -> list[dict]:
"""从 LLM 输出提取 JSON 数组(处理尾逗号、注释等)"""
text = _clean_json_text(text)
try:
result = json.loads(text)
if isinstance(result, list):
return result
except json.JSONDecodeError:
pass
for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']:
match = re.search(pattern, text, re.DOTALL)
if match:
try:
result = json.loads(_clean_json_text(match.group(1)))
if isinstance(result, list):
return result
except json.JSONDecodeError:
continue
match = re.search(r'\[.*\]', text, re.DOTALL)
if match:
try:
return json.loads(_clean_json_text(match.group()))
except json.JSONDecodeError:
pass
return []
def _clean_json_text(s: str) -> str:
"""清理 LLM 常见的非标准 JSON"""
s = re.sub(r'//.*?\n', '\n', s)
s = re.sub(r'/\*.*?\*/', '', s, flags=re.DOTALL)
s = re.sub(r',\s*([}\]])', r'\1', s)
return s.strip()