2026-01-06 19:44:17 +08:00
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
"""
|
|
|
|
|
|
LLM调用辅助模块
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import yaml
|
2026-01-24 12:52:35 +08:00
|
|
|
|
from typing import Optional, Callable, AsyncIterator
|
2026-01-06 19:44:17 +08:00
|
|
|
|
from config.llm_config import LLMConfig
|
2026-01-24 12:52:35 +08:00
|
|
|
|
from config.app_config import app_config
|
2026-01-06 19:44:17 +08:00
|
|
|
|
from utils.fallback_openai_client import AsyncFallbackOpenAIClient
|
2026-01-24 12:52:35 +08:00
|
|
|
|
from utils.cache_manager import LLMCacheManager
|
|
|
|
|
|
|
|
|
|
|
|
# Initialize the LLM cache manager.
# Module-level singleton: shared by every LLMHelper instance in this process.
# Cache directory and the on/off switch are taken from the application config
# (app_config.llm_cache_dir / app_config.llm_cache_enabled).
llm_cache = LLMCacheManager(
    cache_dir=app_config.llm_cache_dir,
    enabled=app_config.llm_cache_enabled
)
|
2026-01-06 19:44:17 +08:00
|
|
|
|
|
|
|
|
|
|
class LLMHelper:
    """Helper for calling an LLM, with sync and async entry points.

    Wraps ``AsyncFallbackOpenAIClient`` and adds:

    * plain async calls (:meth:`async_call`) and a sync wrapper (:meth:`call`),
    * read-through disk caching via the module-level ``llm_cache``
      (:meth:`async_call_with_cache` / :meth:`call_with_cache`),
    * streaming with an optional per-chunk callback (:meth:`async_call_stream`),
    * parsing of YAML answers wrapped in Markdown code fences
      (:meth:`parse_yaml_response`).
    """

    def __init__(self, config: Optional["LLMConfig"] = None):
        """Bind the helper to one LLM endpoint configuration.

        Args:
            config: connection settings; ``api_key``, ``base_url``, ``model``,
                ``max_tokens`` and ``temperature`` are read from it.
                NOTE(review): despite the ``None`` default, a real config is
                required — ``None`` raises AttributeError immediately below;
                confirm all callers pass one.
        """
        self.config = config
        self.client = AsyncFallbackOpenAIClient(
            primary_api_key=config.api_key,
            primary_base_url=config.base_url,
            primary_model_name=config.model,
        )

    @staticmethod
    def _build_messages(prompt: str, system_prompt: Optional[str]) -> list:
        """Build the chat ``messages`` list: optional system turn, then user turn.

        An empty-string system prompt is treated like ``None`` (falsy),
        matching the original inline checks.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    def _resolve_params(self, max_tokens: Optional[int], temperature: Optional[float]) -> dict:
        """Resolve sampling params, falling back to the config per-parameter.

        Uses ``is not None`` so explicit ``0`` / ``0.0`` values are honored
        (bugfix: the streaming path previously used ``or``, which silently
        replaced ``temperature=0.0`` with the config default).
        """
        return {
            'max_tokens': max_tokens if max_tokens is not None else self.config.max_tokens,
            'temperature': temperature if temperature is not None else self.config.temperature,
        }

    async def async_call(self, prompt: str, system_prompt: str = None, max_tokens: int = None, temperature: float = None) -> str:
        """Asynchronously call the LLM and return the reply text.

        Errors are printed and swallowed; returns "" on failure
        (deliberate best-effort behavior — callers treat "" as a miss).
        """
        messages = self._build_messages(prompt, system_prompt)
        kwargs = self._resolve_params(max_tokens, temperature)

        try:
            response = await self.client.chat_completions_create(
                messages=messages,
                **kwargs
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"LLM调用失败: {e}")
            return ""

    @staticmethod
    def _get_event_loop():
        """Return a usable event loop, creating and installing one if needed.

        Applies ``nest_asyncio`` (third-party) so ``run_until_complete`` works
        even when called from inside an already-running loop.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        import nest_asyncio
        nest_asyncio.apply()
        return loop

    def call(self, prompt: str, system_prompt: str = None, max_tokens: int = None, temperature: float = None) -> str:
        """Synchronous wrapper around :meth:`async_call`."""
        loop = self._get_event_loop()
        return loop.run_until_complete(self.async_call(prompt, system_prompt, max_tokens, temperature))

    @staticmethod
    def _extract_fenced_block(response: str) -> str:
        """Extract the payload of a ```` ```yaml ```` or bare ```` ``` ```` fence.

        Falls back to the whole stripped response when no fence is present.
        A fence whose closing ``` is missing is read to the end of the string
        (bugfix: ``str.find`` returning -1 previously made the slice drop the
        payload's last character).
        """
        if '```yaml' in response:
            start = response.find('```yaml') + 7
        elif '```' in response:
            start = response.find('```') + 3
        else:
            return response.strip()
        end = response.find('```', start)
        if end == -1:
            end = len(response)
        return response[start:end].strip()

    def parse_yaml_response(self, response: str) -> dict:
        """Parse an LLM response expected to contain YAML.

        Returns {} on parse failure or when the YAML is empty. NOTE(review):
        if the YAML parses to a non-dict (list/str), that value is returned
        as-is despite the annotation — preserved from the original; confirm
        callers tolerate it.
        """
        try:
            yaml_content = self._extract_fenced_block(response)
            parsed = yaml.safe_load(yaml_content)
            return parsed if parsed is not None else {}
        except Exception as e:
            print(f"YAML解析失败: {e}")
            print(f"原始响应: {response}")
            return {}

    async def close(self):
        """Close the underlying fallback client."""
        await self.client.close()

    async def async_call_with_cache(
        self,
        prompt: str,
        system_prompt: str = None,
        max_tokens: int = None,
        temperature: float = None,
        use_cache: bool = True
    ) -> str:
        """Async LLM call with a read-through cache.

        The cache key is derived from the messages and the model name only —
        max_tokens/temperature are NOT part of the key (preserved from the
        original; differing sampling params share one cache slot).
        Empty responses are neither returned from nor written to the cache.
        """
        messages = self._build_messages(prompt, system_prompt)
        cache_key = llm_cache.get_cache_key_from_messages(messages, self.config.model)

        # Try the cache first.
        if use_cache and app_config.llm_cache_enabled:
            cached_response = llm_cache.get(cache_key)
            if cached_response:
                print("[CACHE] 使用LLM缓存响应")
                return cached_response

        # Miss: call the LLM, then populate the cache on success.
        response = await self.async_call(prompt, system_prompt, max_tokens, temperature)
        if use_cache and app_config.llm_cache_enabled and response:
            llm_cache.set(cache_key, response)
        return response

    def call_with_cache(
        self,
        prompt: str,
        system_prompt: str = None,
        max_tokens: int = None,
        temperature: float = None,
        use_cache: bool = True
    ) -> str:
        """Synchronous wrapper around :meth:`async_call_with_cache`."""
        loop = self._get_event_loop()
        return loop.run_until_complete(
            self.async_call_with_cache(prompt, system_prompt, max_tokens, temperature, use_cache)
        )

    async def async_call_stream(
        self,
        prompt: str,
        system_prompt: str = None,
        max_tokens: int = None,
        temperature: float = None,
        callback: Optional[Callable[[str], None]] = None
    ) -> AsyncIterator[str]:
        """Streaming async call: yields content deltas as they arrive.

        Args:
            callback: optional sync callable invoked with each delta before
                it is yielded.

        On error, prints the failure and yields a single "".
        """
        messages = self._build_messages(prompt, system_prompt)
        kwargs = self._resolve_params(max_tokens, temperature)
        kwargs['stream'] = True

        try:
            response = await self.client.chat_completions_create(
                messages=messages,
                **kwargs
            )
            async for chunk in response:
                # Guard: streamed chunks may carry an empty `choices` list.
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    if callback:
                        callback(content)
                    yield content
        except Exception as e:
            print(f"流式LLM调用失败: {e}")
            yield ""
|