feat: 娣诲姞澶氫釜鏂板姛鑳藉拰淇 - 鍖呮嫭鐢ㄦ埛绠$悊銆佹暟鎹簱杩佺Щ銆丟it鎺ㄩ€佸伐鍏风瓑

This commit is contained in:
赵杰 Jie Zhao (雄狮汽车科技)
2025-11-05 10:16:34 +08:00
parent a4261ef06f
commit c9d5c80f42
43 changed files with 4435 additions and 7439 deletions

View File

@@ -3,43 +3,59 @@
"""
语义相似度计算服务
使用sentence-transformers进行更准确的语义相似度计算
使用LLM API进行更准确的语义相似度计算提高理解力并节约服务端资源
"""
import logging
import numpy as np
import re
from typing import List, Tuple, Optional
from sentence_transformers import SentenceTransformer
import torch
logger = logging.getLogger(__name__)
class SemanticSimilarityCalculator:
"""语义相似度计算器"""
"""语义相似度计算器 - 使用LLM API"""
def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
def __init__(self, use_llm: bool = True):
"""
初始化语义相似度计算器
Args:
model_name: 使用的预训练模型名称
- all-MiniLM-L6-v2: 英文模型,速度快,推荐用于生产环境
- paraphrase-multilingual-MiniLM-L12-v2: 多语言模型,支持中文
- paraphrase-multilingual-mpnet-base-v2: 多语言模型,精度高
use_llm: 是否使用LLM API计算相似度默认True推荐
- True: 使用LLM API理解力更强无需加载本地模型
- False: 使用本地模型需要下载HuggingFace模型
"""
self.model_name = model_name
self.use_llm = use_llm
self.model = None
self._load_model()
self.llm_client = None
if use_llm:
self._init_llm_client()
else:
self._load_model()
def _init_llm_client(self):
"""初始化LLM客户端"""
try:
from ..core.llm_client import QwenClient
self.llm_client = QwenClient()
logger.info("LLM客户端初始化成功将使用LLM API计算语义相似度")
except Exception as e:
logger.error(f"初始化LLM客户端失败: {e}")
self.llm_client = None
# 回退到本地模型
self.use_llm = False
self._load_model()
def _load_model(self):
"""加载预训练模型"""
"""加载预训练模型仅在use_llm=False时使用"""
try:
logger.info(f"正在加载语义相似度模型: {self.model_name}")
self.model = SentenceTransformer(self.model_name)
logger.info("语义相似度模型加载成功")
logger.info(f"正在加载本地语义相似度模型: all-MiniLM-L6-v2")
self.model = SentenceTransformer("all-MiniLM-L6-v2")
logger.info("本地语义相似度模型加载成功")
except Exception as e:
logger.error(f"加载语义相似度模型失败: {e}")
# 回退到简单模型
logger.error(f"加载本地语义相似度模型失败: {e}")
self.model = None
def calculate_similarity(self, text1: str, text2: str, fast_mode: bool = True) -> float:
@@ -49,7 +65,7 @@ class SemanticSimilarityCalculator:
Args:
text1: 第一个文本
text2: 第二个文本
fast_mode: 是否使用快速模式(结合传统方法
fast_mode: 是否使用快速模式(仅在使用本地模型时有效
Returns:
相似度分数 (0-1之间)
@@ -58,27 +74,22 @@ class SemanticSimilarityCalculator:
return 0.0
try:
# 快速模式:先使用传统方法快速筛选
if fast_mode:
tfidf_sim = self._calculate_tfidf_similarity(text1, text2)
# 如果传统方法相似度很高或很低,直接返回
if tfidf_sim >= 0.9:
return tfidf_sim
elif tfidf_sim <= 0.3:
return tfidf_sim
# 中等相似度时,使用语义方法进行精确计算
if self.model is not None:
# 优先使用LLM API计算相似度
if self.use_llm and self.llm_client:
return self._calculate_llm_similarity(text1, text2)
# 回退到本地模型或TF-IDF
if self.model is not None:
if fast_mode:
# 快速模式先使用TF-IDF快速筛选
tfidf_sim = self._calculate_tfidf_similarity(text1, text2)
if tfidf_sim >= 0.9 or tfidf_sim <= 0.3:
return tfidf_sim
# 中等相似度时,使用语义方法进行精确计算
semantic_sim = self._calculate_semantic_similarity(text1, text2)
# 结合两种方法的结果
return (tfidf_sim * 0.3 + semantic_sim * 0.7)
else:
return tfidf_sim
# 完整模式:直接使用语义相似度
if self.model is not None:
return self._calculate_semantic_similarity(text1, text2)
return self._calculate_semantic_similarity(text1, text2)
else:
return self._calculate_tfidf_similarity(text1, text2)
@@ -86,6 +97,80 @@ class SemanticSimilarityCalculator:
logger.error(f"计算语义相似度失败: {e}")
return self._calculate_tfidf_similarity(text1, text2)
def _calculate_llm_similarity(self, text1: str, text2: str) -> float:
"""使用LLM API计算语义相似度"""
try:
# 构建prompt让LLM比较两个文本的相似度
prompt = f"""请比较以下两个文本的语义相似度并给出0-1之间的分数保留2位小数其中
- 1.0 表示完全相同
- 0.8-0.9 表示非常相似
- 0.6-0.7 表示较为相似
- 0.4-0.5 表示部分相似
- 0.0-0.3 表示差异很大
文本1: {text1}
文本2: {text2}
请只返回0-1之间的数字保留2位小数不要包含其他文字。例如0.85"""
messages = [
{"role": "system", "content": "你是一个专业的文本相似度评估专家,请准确评估两个文本的语义相似度。"},
{"role": "user", "content": prompt}
]
result = self.llm_client.chat_completion(
messages=messages,
temperature=0.1, # 低温度以获得更稳定的结果
max_tokens=50
)
if "error" in result:
logger.error(f"LLM API调用失败: {result['error']}")
# 回退到TF-IDF
return self._calculate_tfidf_similarity(text1, text2)
# 提取响应中的数字
response_content = result.get("choices", [{}])[0].get("message", {}).get("content", "")
similarity = self._extract_similarity_from_response(response_content)
logger.debug(f"LLM计算语义相似度: {similarity:.4f}")
return similarity
except Exception as e:
logger.error(f"LLM语义相似度计算失败: {e}")
# 回退到TF-IDF
return self._calculate_tfidf_similarity(text1, text2)
def _extract_similarity_from_response(self, response: str) -> float:
"""从LLM响应中提取相似度分数"""
try:
# 尝试提取0-1之间的浮点数
patterns = [
r'(\d+\.\d{1,2})', # 匹配两位小数的浮点数
r'(\d+\.\d+)', # 匹配任意小数的浮点数
r'(\d+)' # 匹配整数(可能是百分比形式)
]
for pattern in patterns:
matches = re.findall(pattern, response)
if matches:
value = float(matches[0])
# 如果值大于1可能是百分比形式需要除以100
if value > 1:
value = value / 100.0
# 确保在0-1范围内
value = max(0.0, min(1.0, value))
return value
# 如果没有找到数字,返回默认值
logger.warning(f"无法从响应中提取相似度分数: {response}")
return 0.5
except Exception as e:
logger.error(f"提取相似度分数失败: {e}, 响应: {response}")
return 0.5
def _calculate_semantic_similarity(self, text1: str, text2: str) -> float:
"""使用sentence-transformers计算语义相似度"""
try:
@@ -159,6 +244,11 @@ class SemanticSimilarityCalculator:
return []
try:
# 优先使用LLM API
if self.use_llm and self.llm_client:
return [self._calculate_llm_similarity(t1, t2) for t1, t2 in text_pairs]
# 回退到本地模型或TF-IDF
if self.model is not None:
return self._batch_semantic_similarity(text_pairs)
else:
@@ -214,17 +304,24 @@ class SemanticSimilarityCalculator:
return "语义差异较大,建议重新生成"
def is_model_available(self) -> bool:
"""检查模型是否可用"""
return self.model is not None
"""检查模型是否可用LLM或本地模型"""
if self.use_llm:
return self.llm_client is not None
else:
return self.model is not None
# 全局实例
_similarity_calculator = None
def get_similarity_calculator() -> SemanticSimilarityCalculator:
"""获取全局相似度计算器实例"""
def get_similarity_calculator(use_llm: bool = True) -> SemanticSimilarityCalculator:
"""获取全局相似度计算器实例
Args:
use_llm: 是否使用LLM API默认True推荐
"""
global _similarity_calculator
if _similarity_calculator is None:
_similarity_calculator = SemanticSimilarityCalculator()
_similarity_calculator = SemanticSimilarityCalculator(use_llm=use_llm)
return _similarity_calculator
def calculate_semantic_similarity(text1: str, text2: str, fast_mode: bool = True) -> float: