# -*- coding: utf-8 -*-


"""
Embedding 向量客户端(本地模型方案)

使用 sentence-transformers 在本地运行轻量级中文 embedding 模型
零 API 调用、零成本、低延迟
"""

import logging
import hashlib
import threading
from typing import List, Optional
from src.config.unified_config import get_config
from src.core.cache_manager import cache_manager
logger = logging.getLogger(__name__)


class EmbeddingClient:
    """Local embedding vector client.

    Lazily loads a sentence-transformers model on first use (thread-safe)
    and caches generated vectors through ``cache_manager``.  When the
    feature is disabled, or the model fails to load, every public method
    degrades gracefully by returning ``None`` placeholders so callers can
    fall back to keyword matching.
    """

    def __init__(self) -> None:
        config = get_config()
        # Feature flag: when False, embed_* methods return None placeholders.
        self.enabled: bool = config.embedding.enabled
        self.model_name: str = config.embedding.model
        self.dimension: int = config.embedding.dimension
        # Cache entry lifetime passed to cache_manager.set — presumably
        # seconds; confirm against cache_manager's contract.
        self.cache_ttl: int = config.embedding.cache_ttl

        self._model = None  # loaded lazily by _get_model()
        self._lock = threading.Lock()  # guards one-time model loading

        if self.enabled:
            logger.info(f"Embedding 客户端初始化: model={self.model_name} (本地模式)")
        else:
            logger.info("Embedding 功能已禁用,将使用关键词匹配降级")

    def _get_model(self):
        """Lazily load the model on first call (downloads it if needed).

        Returns the loaded ``SentenceTransformer``, or ``None`` when loading
        fails — in which case ``self.enabled`` is flipped off so subsequent
        calls short-circuit instead of retrying a broken load.
        """
        if self._model is not None:
            return self._model

        with self._lock:
            # Double-checked locking: another thread may have loaded the
            # model while we were waiting for the lock.
            if self._model is not None:
                return self._model

            try:
                import os

                # Point HuggingFace at a mirror to work around download
                # problems in mainland China; an existing HF_ENDPOINT wins.
                os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

                from sentence_transformers import SentenceTransformer

                logger.info(f"正在加载 embedding 模型: {self.model_name} ...")
                self._model = SentenceTransformer(self.model_name)
                logger.info(f"Embedding 模型加载完成: {self.model_name}")
                return self._model
            except ImportError:
                logger.error(
                    "sentence-transformers 未安装,请运行: pip install sentence-transformers"
                )
                self.enabled = False
                return None
            except Exception:
                # logger.exception records the full traceback, which the
                # original logger.error(f"...{e}") discarded.
                logger.exception("加载 embedding 模型失败")
                self.enabled = False
                return None

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def embed_text(self, text: str) -> Optional[List[float]]:
        """Embed a single text, consulting the cache first.

        Returns a normalized vector as a list of floats, or ``None`` when
        the feature is disabled, the text is blank, or inference fails.
        """
        if not self.enabled or not text.strip():
            return None

        cache_key = self._cache_key(text)
        cached = cache_manager.get(cache_key)
        if cached is not None:
            return cached

        model = self._get_model()
        if model is None:
            return None

        try:
            # normalize_embeddings=True yields unit vectors, so cosine
            # similarity reduces to a plain dot product downstream.
            vec = model.encode(text, normalize_embeddings=True).tolist()
            cache_manager.set(cache_key, vec, self.cache_ttl)
            return vec
        except Exception:
            logger.exception("Embedding 生成失败")
            return None

    def embed_batch(self, texts: List[str]) -> List[Optional[List[float]]]:
        """Embed a batch of texts.

        Cached entries are reused; the remaining texts are encoded in a
        single batched inference call.  The result list is positionally
        aligned with ``texts``; blank texts and failures yield ``None``.
        """
        if not self.enabled:
            return [None] * len(texts)

        results: List[Optional[List[float]]] = [None] * len(texts)
        # (position in `texts`, text, cache key) for every cache miss.
        # The key is computed once here and reused when storing below,
        # instead of hashing each text twice as the original did.
        pending = []

        # 1. Serve whatever we can from the cache.
        for i, text in enumerate(texts):
            if not text.strip():
                continue
            key = self._cache_key(text)
            cached = cache_manager.get(key)
            if cached is not None:
                results[i] = cached
            else:
                pending.append((i, text, key))

        if not pending:
            return results

        # 2. One batched inference pass for the cache misses.
        model = self._get_model()
        if model is None:
            return results

        try:
            vectors = model.encode(
                [text for _, text, _ in pending],
                normalize_embeddings=True,
                batch_size=32,
            ).tolist()
            for (i, _, key), vec in zip(pending, vectors):
                results[i] = vec
                cache_manager.set(key, vec, self.cache_ttl)
        except Exception:
            logger.exception("批量 embedding 生成失败")

        return results

    def test_connection(self) -> bool:
        """Smoke-test the model by embedding a short fixed string."""
        try:
            vec = self.embed_text("测试连接")
            return vec is not None and len(vec) > 0
        except Exception:
            logger.exception("Embedding 模型测试失败")
            return False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _cache_key(text: str) -> str:
        """Cache key derived from the MD5 of the text (not security-sensitive)."""
        digest = hashlib.md5(text.encode("utf-8")).hexdigest()
        return f"emb:{digest}"