104 lines
3.2 KiB
Python
104 lines
3.2 KiB
Python
# -*- coding: utf-8 -*-
"""
Cache manager - supports caching of data and of LLM responses.
"""
|
|
|
|
import os
|
|
import json
|
|
import hashlib
|
|
import pickle
|
|
from pathlib import Path
|
|
from typing import Any, Optional, Callable
|
|
from functools import wraps
|
|
|
|
|
|
class CacheManager:
    """Disk-backed cache that stores each entry as one pickle file.

    Entries live in ``cache_dir`` as ``<key>.pkl``. When the manager is
    created with ``enabled=False`` every operation except ``clear`` becomes
    a no-op and the cache directory is not created.

    NOTE(security): entries are read with ``pickle.load``, which can execute
    arbitrary code. Only point ``cache_dir`` at a location whose contents
    you trust.
    """

    def __init__(self, cache_dir: str = ".cache", enabled: bool = True):
        """Remember the settings and create the cache directory if enabled.

        Args:
            cache_dir: Directory that holds the ``*.pkl`` cache files.
            enabled: When False, reads always miss and writes are skipped.
        """
        self.cache_dir = Path(cache_dir)
        self.enabled = enabled

        if self.enabled:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _get_cache_key(self, *args, **kwargs) -> str:
        """Build an MD5 hex digest from the ``repr`` of the given arguments.

        Keyword arguments are ordered by name first, so ``f(a=1, b=2)`` and
        ``f(b=2, a=1)`` map to the same key. The key is only as stable as
        the ``repr`` of each argument.
        """
        ordered_kwargs = {name: kwargs[name] for name in sorted(kwargs)}
        key_data = f"{args}_{ordered_kwargs}"
        return hashlib.md5(key_data.encode()).hexdigest()

    def _get_cache_path(self, key: str) -> Path:
        """Return the path of the pickle file that stores ``key``."""
        return self.cache_dir / f"{key}.pkl"

    def _load(self, key: str) -> tuple:
        """Read one entry and report whether it was actually found.

        Returns:
            ``(True, value)`` when the entry exists and unpickles cleanly,
            otherwise ``(False, None)``. The flag lets callers tell a stored
            ``None`` apart from a miss.
        """
        cache_path = self._get_cache_path(key)
        if not cache_path.exists():
            return False, None

        try:
            with open(cache_path, 'rb') as f:
                return True, pickle.load(f)
        except Exception as e:
            # Best effort: an unreadable entry is reported and treated as a miss.
            print(f"⚠️ 读取缓存失败: {e}")
            return False, None

    def get(self, key: str) -> Optional[Any]:
        """Return the value stored under ``key``.

        Returns None when the cache is disabled, the entry is missing, or the
        file cannot be unpickled (a warning is printed in that last case).
        """
        if not self.enabled:
            return None

        _, value = self._load(key)
        return value

    def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``; failures are printed, never raised.

        The value is pickled into a temporary file which then replaces the
        real cache file. A failed write (for example an unpicklable value)
        therefore cannot truncate or overwrite an existing entry.
        """
        if not self.enabled:
            return

        cache_path = self._get_cache_path(key)
        # The ".tmp" suffix keeps in-progress files out of the "*.pkl" namespace.
        tmp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(value, f)
            os.replace(tmp_path, cache_path)
        except Exception as e:
            print(f"⚠️ 写入缓存失败: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                # The temp file was never created or is already gone.
                pass

    def clear(self) -> None:
        """Delete every ``*.pkl`` file in the cache directory, if it exists."""
        if self.cache_dir.exists():
            for cache_file in self.cache_dir.glob("*.pkl"):
                cache_file.unlink()
            print("✅ 缓存已清空")

    def cached(self, key_func: Optional[Callable] = None):
        """Return a decorator that caches a function's results on disk.

        Args:
            key_func: Optional callable that receives the wrapped function's
                arguments and returns the cache key. When omitted, the key is
                derived from the function's module and qualified name plus
                its arguments, so different functions called with the same
                arguments do not share entries.

        A stored ``None`` result counts as a hit; the wrapped function is not
        called again for it. When the manager is disabled the wrapped
        function is always called directly.
        """
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                if not self.enabled:
                    return func(*args, **kwargs)

                # Build the cache key.
                if key_func:
                    cache_key = key_func(*args, **kwargs)
                else:
                    cache_key = self._get_cache_key(
                        getattr(func, "__module__", ""),
                        getattr(func, "__qualname__", repr(func)),
                        *args,
                        **kwargs,
                    )

                # Try the cache first; the hit flag makes None a valid value.
                hit, cached_value = self._load(cache_key)
                if hit:
                    print(f"💾 使用缓存: {cache_key[:8]}...")
                    return cached_value

                # Miss: run the function and store its result.
                result = func(*args, **kwargs)
                self.set(cache_key, result)
                return result

            return wrapper
        return decorator
|
|
|
|
|
|
class LLMCacheManager(CacheManager):
    """Cache manager for LLM responses, keyed by the request messages."""

    def get_cache_key_from_messages(self, messages: list, model: str = "") -> str:
        """Derive a cache key from a message list and a model name.

        The messages are serialised to JSON with sorted dictionary keys, the
        model name is appended, and the MD5 hex digest of the result is
        returned.

        Args:
            messages: JSON-serialisable message list.
            model: Model identifier appended to the serialised messages.

        Returns:
            A 32-character hexadecimal MD5 digest.
        """
        serialized = json.dumps(messages, sort_keys=True)
        digest = hashlib.md5()
        digest.update((serialized + model).encode())
        return digest.hexdigest()
|