#!/usr/bin/env python
# -*- coding: utf-8 -*-
|
|||
|
|
"""
Batch-generate embedding vectors for existing knowledge-base entries (local model).

Usage: python scripts/migrate_embeddings.py

The first run downloads the model automatically (~95MB); later runs use the local cache.
"""
|
|||
|
|
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
|
|||
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||
|
|
|
|||
|
|
from src.config.unified_config import get_config
|
|||
|
|
from src.core.database import db_manager
|
|||
|
|
from src.core.models import KnowledgeEntry
|
|||
|
|
from src.core.embedding_client import EmbeddingClient
|
|||
|
|
from src.core.vector_store import vector_store
|
|||
|
|
|
|||
|
|
# Configure logging for this script: INFO level, each line prefixed with a
# timestamp and the level name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Module-level logger used for all progress and error reporting in this script.
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def migrate():
    """Generate and store embeddings for active knowledge entries that lack one.

    Steps:
      1. Return early (with a warning) when ``config.embedding.enabled`` is falsy.
      2. Check the embedding client with ``test_connection()``; return on failure.
      3. Load every active ``KnowledgeEntry`` and keep those whose
         ``vector_embedding`` is empty or whitespace-only.
      4. Embed ``"<question> <answer>"`` for those entries with a single
         ``embed_batch`` call and store each non-empty vector as a JSON string.
      5. Commit, then rebuild the vector index via ``vector_store.load_from_db()``.

    Returns:
        None. Progress and failures are reported through ``logger``.
    """
    config = get_config()
    if not config.embedding.enabled:
        logger.warning("Embedding 功能未启用,请在 .env 中设置 EMBEDDING_ENABLED=True")
        return

    client = EmbeddingClient()

    # Make sure the model can be loaded before touching the database.
    logger.info("正在加载本地 embedding 模型(首次运行需下载)...")
    if not client.test_connection():
        logger.error("Embedding 模型加载失败,请检查 sentence-transformers 是否安装")
        return
    logger.info("模型加载成功")

    # Entries are modified and committed inside the session that loaded them.
    with db_manager.get_session() as session:
        entries = session.query(KnowledgeEntry).filter(
            KnowledgeEntry.is_active == True  # noqa: E712 - builds a SQL expression
        ).all()

        # Keep only the entries that do not have an embedding yet.
        to_process = [
            entry
            for entry in entries
            if not entry.vector_embedding or entry.vector_embedding.strip() == ""
        ]

        logger.info(f"共 {len(entries)} 条活跃知识,{len(to_process)} 条需要生成 embedding")

        if not to_process:
            logger.info("所有条目已有 embedding,无需迁移")
            return

        # `or ''` keeps a missing question/answer from raising TypeError.
        texts = [f"{e.question or ''} {e.answer or ''}" for e in to_process]
        logger.info("开始批量生成 embedding...")
        vectors = client.embed_batch(texts)

        # Vectors are paired with entries by position, so a count mismatch
        # means the pairing cannot be trusted: abort before writing anything.
        returned = 0 if vectors is None else len(vectors)
        if returned != len(to_process):
            logger.error(
                f"Embedding 数量不匹配: 期望 {len(to_process)},实际 {returned},已中止写入"
            )
            return

        success_count = 0
        for entry, vec in zip(to_process, vectors):
            # A falsy vector marks a failed item; leave that entry untouched.
            if vec:
                entry.vector_embedding = json.dumps(vec)
                success_count += 1

        session.commit()
        logger.info(f"Embedding 生成完成: 成功 {success_count}/{len(to_process)}")

    # Rebuild the vector index from the freshly committed data.
    vector_store.load_from_db()
    logger.info(f"向量索引重建完成: {vector_store.size} 条")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the migration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    migrate()
|