Files
assist/src/core/embedding_client.py
2026-03-20 16:50:26 +08:00

153 lines
5.0 KiB
Python

# -*- coding: utf-8 -*-
"""
Embedding 向量客户端(本地模型方案)
使用 sentence-transformers 在本地运行轻量级中文 embedding 模型
零 API 调用、零成本、低延迟
"""
import logging
import hashlib
import threading
from typing import List, Optional
from src.config.unified_config import get_config
from src.core.cache_manager import cache_manager
logger = logging.getLogger(__name__)
class EmbeddingClient:
    """Local embedding vector client.

    Runs a lightweight sentence-transformers model in-process (no API
    calls). The model is loaded lazily on first use under a lock, and
    every generated vector is cached through the shared ``cache_manager``
    keyed by an md5 digest of the input text.
    """

    def __init__(self):
        cfg = get_config()
        self.enabled = cfg.embedding.enabled
        self.model_name = cfg.embedding.model
        self.dimension = cfg.embedding.dimension
        self.cache_ttl = cfg.embedding.cache_ttl
        self._model = None
        self._lock = threading.Lock()
        if not self.enabled:
            logger.info("Embedding 功能已禁用,将使用关键词匹配降级")
        else:
            logger.info(f"Embedding 客户端初始化: model={self.model_name} (本地模式)")

    def _get_model(self):
        """Return the loaded model, loading it on first call (double-checked lock)."""
        if self._model is not None:
            return self._model
        with self._lock:
            if self._model is None:
                self._load_model_locked()
            # Still None here means loading failed and self.enabled was cleared.
            return self._model

    def _load_model_locked(self):
        """Load the sentence-transformers model; caller must hold self._lock.

        On failure the client is disabled (self.enabled = False) and
        self._model stays None so callers degrade gracefully.
        """
        try:
            import os
            # Point HuggingFace at a mirror to work around blocked downloads.
            os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
            from sentence_transformers import SentenceTransformer
            logger.info(f"正在加载 embedding 模型: {self.model_name} ...")
            self._model = SentenceTransformer(self.model_name)
            logger.info(f"Embedding 模型加载完成: {self.model_name}")
        except ImportError:
            logger.error(
                "sentence-transformers 未安装,请运行: pip install sentence-transformers"
            )
            self.enabled = False
        except Exception as e:
            logger.error(f"加载 embedding 模型失败: {e}")
            self.enabled = False

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def embed_text(self, text: str) -> Optional[List[float]]:
        """Embed a single text, consulting the cache first.

        Returns None when the client is disabled, the text is blank,
        the model cannot be loaded, or encoding fails.
        """
        if not self.enabled or not text.strip():
            return None
        key = self._cache_key(text)
        hit = cache_manager.get(key)
        if hit is not None:
            return hit
        model = self._get_model()
        if model is None:
            return None
        try:
            vector = model.encode(text, normalize_embeddings=True).tolist()
            cache_manager.set(key, vector, self.cache_ttl)
        except Exception as e:
            logger.error(f"Embedding 生成失败: {e}")
            return None
        return vector

    def embed_batch(self, texts: List[str]) -> List[Optional[List[float]]]:
        """Embed a batch of texts; blank or failed entries stay None."""
        if not self.enabled:
            return [None] * len(texts)
        results: List[Optional[List[float]]] = [None] * len(texts)
        # Positions/texts that missed the cache, kept as (index, text) pairs.
        pending = []
        for idx, item in enumerate(texts):
            if not item.strip():
                continue
            hit = cache_manager.get(self._cache_key(item))
            if hit is None:
                pending.append((idx, item))
            else:
                results[idx] = hit
        if not pending:
            return results
        # Single batched inference for everything the cache did not cover.
        model = self._get_model()
        if model is None:
            return results
        try:
            vectors = model.encode(
                [item for _, item in pending], normalize_embeddings=True, batch_size=32
            ).tolist()
            for (idx, item), vec in zip(pending, vectors):
                results[idx] = vec
                cache_manager.set(self._cache_key(item), vec, self.cache_ttl)
        except Exception as e:
            logger.error(f"批量 embedding 生成失败: {e}")
        return results

    def test_connection(self) -> bool:
        """Return True if the model produces a non-empty vector for a probe text."""
        try:
            vector = self.embed_text("测试连接")
        except Exception as e:
            logger.error(f"Embedding 模型测试失败: {e}")
            return False
        return vector is not None and len(vector) > 0

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    @staticmethod
    def _cache_key(text: str) -> str:
        """Derive the cache key from an md5 digest of the text."""
        digest = hashlib.md5(text.encode("utf-8")).hexdigest()
        return "emb:" + digest