Files
iov_data_analysis_agent/utils/cache_manager.py

104 lines
3.2 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
"""
缓存管理器 - 支持数据和LLM响应缓存
"""
import os
import json
import hashlib
import pickle
from pathlib import Path
from typing import Any, Optional, Callable
from functools import wraps
class CacheManager:
    """Pickle-backed on-disk cache for arbitrary Python values.

    Each entry is stored as ``<cache_dir>/<key>.pkl``. When ``enabled`` is
    False, ``get`` returns its default, ``set`` writes nothing, ``cached``
    calls straight through, and the cache directory is not created.

    Security note: entries are read back with ``pickle.load``, which can
    execute arbitrary code. Only point ``cache_dir`` at a directory that
    untrusted parties cannot write to.
    """

    # Private sentinel so a cached ``None`` can be told apart from a miss.
    _MISSING = object()

    def __init__(self, cache_dir: str = ".cache", enabled: bool = True):
        """Create the manager.

        Args:
            cache_dir: Directory holding the ``.pkl`` files; created (with
                parents) when caching is enabled.
            enabled: Master switch; False turns the cache into a no-op.
        """
        self.cache_dir = Path(cache_dir)
        self.enabled = enabled
        if self.enabled:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _get_cache_key(self, *args, **kwargs) -> str:
        """Return an MD5 hex digest derived from the call arguments.

        Keyword arguments are sorted by name first, so ``f(a=1, b=2)`` and
        ``f(b=2, a=1)`` share a key. The key is built from the string form
        of the arguments, so it is only stable across runs for values whose
        repr is stable; objects using the default ``object.__repr__`` (which
        embeds a memory address), such as ``self`` when a method is
        decorated, yield a different key on every run.
        """
        key_data = f"{args}_{dict(sorted(kwargs.items()))}"
        return hashlib.md5(key_data.encode()).hexdigest()

    def _get_cache_path(self, key: str) -> Path:
        """Map a cache key to its ``.pkl`` file directly inside ``cache_dir``.

        Keys that are not a single plain filename component (empty, ``.``,
        ``..``, or containing a path separator) are replaced by their MD5
        digest, so a custom ``key_func`` cannot place files outside
        ``cache_dir`` or into non-existent subdirectories. Plain keys, such
        as the digests produced by ``_get_cache_key``, are used verbatim.
        """
        key = str(key)
        if key in ("", ".", "..") or Path(key).name != key:
            key = hashlib.md5(key.encode()).hexdigest()
        return self.cache_dir / f"{key}.pkl"

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value cached under ``key``, or ``default``.

        ``default`` (None unless given) is returned when caching is disabled,
        the entry does not exist, or the file cannot be read/unpickled; in
        the last case a warning is printed instead of raising.
        """
        if not self.enabled:
            return default
        cache_path = self._get_cache_path(key)
        try:
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
        except FileNotFoundError:
            return default
        except Exception as e:
            # Deliberately best-effort: a corrupt entry behaves like a miss.
            print(f"⚠️ 读取缓存失败: {e}")
            return default

    def set(self, key: str, value: Any) -> None:
        """Pickle ``value`` under ``key``; failures are printed, not raised.

        The value is pickled into a temporary file in the cache directory
        and then moved into place with ``os.replace``, so a pickling error or
        an interrupted write never leaves a truncated ``.pkl`` behind.
        Because the temp file comes from ``tempfile.mkstemp``, the resulting
        entry is readable/writable only by the creating user.
        """
        if not self.enabled:
            return
        import tempfile  # local import: only needed on the write path

        cache_path = self._get_cache_path(key)
        tmp_path = None
        try:
            fd, tmp_path = tempfile.mkstemp(
                dir=cache_path.parent, prefix=".", suffix=".tmp"
            )
            with os.fdopen(fd, 'wb') as f:
                pickle.dump(value, f)
            os.replace(tmp_path, cache_path)
            tmp_path = None  # the temp file now *is* cache_path
        except Exception as e:
            print(f"⚠️ 写入缓存失败: {e}")
        finally:
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass  # best effort; get() and clear() only look at *.pkl

    def clear(self) -> None:
        """Delete every ``*.pkl`` entry in ``cache_dir``; other files are kept."""
        if self.cache_dir.exists():
            for cache_file in self.cache_dir.glob("*.pkl"):
                try:
                    cache_file.unlink()
                except FileNotFoundError:
                    pass  # already removed (e.g. by a concurrent clear)
        print("✅ 缓存已清空")

    def cached(self, key_func: Optional[Callable] = None):
        """Decorator that memoizes a function's return value on disk.

        Args:
            key_func: Optional callable receiving the same arguments as the
                decorated function and returning the cache key. When omitted,
                the key is derived from the arguments by ``_get_cache_key``.

        Every result, including ``None``, is cached. ``self.enabled`` is read
        on each call; when it is False the function is simply invoked.
        """
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                if not self.enabled:
                    return func(*args, **kwargs)
                if key_func is not None:
                    cache_key = key_func(*args, **kwargs)
                else:
                    cache_key = self._get_cache_key(*args, **kwargs)
                # Use a sentinel rather than None to detect a miss, so
                # functions that legitimately return None are not re-run.
                cached_value = self.get(cache_key, self._MISSING)
                if cached_value is not self._MISSING:
                    print(f"💾 使用缓存: {str(cache_key)[:8]}...")
                    return cached_value
                result = func(*args, **kwargs)
                self.set(cache_key, result)
                return result
            return wrapper
        return decorator
class LLMCacheManager(CacheManager):
    """Cache manager with a key helper for LLM chat requests.

    Storage (get/set/clear/cached) is inherited unchanged from CacheManager;
    this subclass only adds a way to derive a key from a message list.
    """

    def get_cache_key_from_messages(self, messages: list, model: str = "") -> str:
        """Return an MD5 hex digest identifying ``messages`` sent to ``model``.

        ``messages`` is serialised with ``json.dumps(..., sort_keys=True)``,
        so dict key order inside the messages does not change the key; the
        model name is appended to that text before hashing. ``json.dumps``
        raises ``TypeError`` for values it cannot serialise.
        """
        canonical = json.dumps(messages, sort_keys=True)
        fingerprint = hashlib.md5((canonical + model).encode("utf-8"))
        return fingerprint.hexdigest()