Files
iov_data_analysis_agent/utils/llm_helper.py
2026-01-31 18:00:05 +08:00

194 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
LLM调用辅助模块
"""
import asyncio
import yaml
from typing import Optional, Callable, AsyncIterator
from config.llm_config import LLMConfig
from config.app_config import app_config
from utils.fallback_openai_client import AsyncFallbackOpenAIClient
from utils.cache_manager import LLMCacheManager
# Module-level LLM response cache, shared by every LLMHelper instance.
# Location and on/off switch come from app_config; used by the
# *_with_cache call paths below to skip repeated identical prompts.
llm_cache = LLMCacheManager(
cache_dir=app_config.llm_cache_dir,
enabled=app_config.llm_cache_enabled
)
class LLMHelper:
    """Helper for LLM calls, supporting sync and async invocation.

    Wraps an ``AsyncFallbackOpenAIClient`` and layers on:
    plain calls (``call`` / ``async_call``), cached calls backed by the
    module-level ``llm_cache``, streaming calls, and YAML response parsing.
    All failure paths are best-effort: errors are logged and an empty
    result is returned rather than raised to the caller.
    """

    def __init__(self, config: "LLMConfig" = None):
        """Bind the helper to an LLM configuration.

        Args:
            config: connection settings (api_key, base_url, model,
                max_tokens, temperature). The ``None`` default is kept for
                signature compatibility, but a real config is required.

        Raises:
            ValueError: if *config* is None. Previously this surfaced as
                an opaque ``AttributeError`` on ``config.api_key``.
        """
        if config is None:
            raise ValueError("LLMHelper requires a non-None LLMConfig")
        self.config = config
        self.client = AsyncFallbackOpenAIClient(
            primary_api_key=config.api_key,
            primary_base_url=config.base_url,
            primary_model_name=config.model,
        )

    @staticmethod
    def _build_messages(prompt: str, system_prompt: str = None) -> list:
        """Build the chat message list shared by every call path."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    def _resolve_params(self, max_tokens: int = None, temperature: float = None) -> dict:
        """Fill request params from config defaults.

        Uses ``is not None`` (not truthiness) so an explicit
        ``temperature=0.0`` or ``max_tokens=0`` is honored — the streaming
        path previously used ``x or default`` and silently dropped 0.0.
        """
        return {
            'max_tokens': max_tokens if max_tokens is not None else self.config.max_tokens,
            'temperature': temperature if temperature is not None else self.config.temperature,
        }

    def _run_sync(self, coro):
        """Drive *coro* to completion from synchronous code.

        Reuses the current event loop when one exists; otherwise creates
        and installs a fresh one. ``nest_asyncio`` makes
        ``run_until_complete`` re-entrant in environments (e.g. Jupyter)
        where a loop is already running.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        import nest_asyncio
        nest_asyncio.apply()
        return loop.run_until_complete(coro)

    async def async_call(self, prompt: str, system_prompt: str = None,
                         max_tokens: int = None, temperature: float = None) -> str:
        """Asynchronously call the LLM and return the response text.

        Args:
            prompt: user message content.
            system_prompt: optional system message prepended to the chat.
            max_tokens / temperature: per-call overrides; ``None`` falls
                back to the bound config.

        Returns:
            The model's text, or ``""`` on any failure (best-effort).
        """
        messages = self._build_messages(prompt, system_prompt)
        kwargs = self._resolve_params(max_tokens, temperature)
        try:
            response = await self.client.chat_completions_create(
                messages=messages,
                **kwargs
            )
            return response.choices[0].message.content
        except Exception as e:
            # Deliberate best-effort contract: log and return empty so
            # callers treat failure as "no answer" rather than crashing.
            print(f"LLM调用失败: {e}")
            return ""

    def call(self, prompt: str, system_prompt: str = None,
             max_tokens: int = None, temperature: float = None) -> str:
        """Synchronous wrapper around :meth:`async_call`."""
        return self._run_sync(
            self.async_call(prompt, system_prompt, max_tokens, temperature)
        )

    def parse_yaml_response(self, response: str) -> dict:
        """Extract and parse a YAML payload from an LLM response.

        Accepts a ```` ```yaml ```` fence, a bare ```` ``` ```` fence, or
        raw YAML. A missing closing fence no longer truncates the payload
        (``find`` returning -1 previously sliced off the last character).

        Returns:
            The parsed mapping, or ``{}`` on empty/invalid YAML.
        """
        try:
            if '```yaml' in response:
                start = response.find('```yaml') + 7
            elif '```' in response:
                start = response.find('```') + 3
            else:
                start = -1  # no fence: treat the whole response as YAML
            if start >= 0:
                end = response.find('```', start)
                # end == -1 means the closing fence is missing — keep the tail.
                yaml_content = (response[start:end] if end != -1 else response[start:]).strip()
            else:
                yaml_content = response.strip()
            parsed = yaml.safe_load(yaml_content)
            return parsed if parsed is not None else {}
        except Exception as e:
            print(f"YAML解析失败: {e}")
            print(f"原始响应: {response}")
            return {}

    async def close(self):
        """Close the underlying LLM client."""
        await self.client.close()

    async def async_call_with_cache(
        self,
        prompt: str,
        system_prompt: str = None,
        max_tokens: int = None,
        temperature: float = None,
        use_cache: bool = True
    ) -> str:
        """Async LLM call with a read-through cache.

        Cache key is derived from the messages + model only, so
        ``max_tokens``/``temperature`` overrides do NOT vary the key —
        presumably acceptable for this app; confirm if overrides matter.
        Failed (empty) responses are never cached.
        """
        messages = self._build_messages(prompt, system_prompt)
        cache_key = llm_cache.get_cache_key_from_messages(messages, self.config.model)
        if use_cache and app_config.llm_cache_enabled:
            cached_response = llm_cache.get(cache_key)
            if cached_response:
                print("[CACHE] 使用LLM缓存响应")
                return cached_response
        response = await self.async_call(prompt, system_prompt, max_tokens, temperature)
        if use_cache and app_config.llm_cache_enabled and response:
            llm_cache.set(cache_key, response)
        return response

    def call_with_cache(
        self,
        prompt: str,
        system_prompt: str = None,
        max_tokens: int = None,
        temperature: float = None,
        use_cache: bool = True
    ) -> str:
        """Synchronous wrapper around :meth:`async_call_with_cache`."""
        return self._run_sync(
            self.async_call_with_cache(prompt, system_prompt, max_tokens, temperature, use_cache)
        )

    async def async_call_stream(
        self,
        prompt: str,
        system_prompt: str = None,
        max_tokens: int = None,
        temperature: float = None,
        callback: Optional[Callable[[str], None]] = None
    ) -> AsyncIterator[str]:
        """Streaming async LLM call; yields content chunks as they arrive.

        Args:
            callback: optional sink invoked with each chunk before it is
                yielded (e.g. for live UI updates).

        Yields:
            Non-empty content deltas; a single ``""`` on failure.

        Fix: overrides now use ``is not None`` semantics via
        ``_resolve_params``, so ``temperature=0.0`` is no longer silently
        replaced by the config default.
        """
        messages = self._build_messages(prompt, system_prompt)
        kwargs = self._resolve_params(max_tokens, temperature)
        kwargs['stream'] = True
        try:
            response = await self.client.chat_completions_create(
                messages=messages,
                **kwargs
            )
            async for chunk in response:
                content = chunk.choices[0].delta.content
                if content:
                    if callback:
                        callback(content)
                    yield content
        except Exception as e:
            print(f"流式LLM调用失败: {e}")
            yield ""