Files
assist/scripts/migrate_embeddings.py

82 lines
2.6 KiB
Python
Raw Normal View History

2026-03-20 16:50:26 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
批量为已有知识库条目生成 Embedding 向量本地模型
运行方式: python scripts/migrate_embeddings.py
首次运行会自动下载模型~95MB之后走本地缓存
"""
import sys
import os
import json
import logging
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.config.unified_config import get_config
from src.core.database import db_manager
from src.core.models import KnowledgeEntry
from src.core.embedding_client import EmbeddingClient
from src.core.vector_store import vector_store
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
def migrate():
config = get_config()
if not config.embedding.enabled:
logger.warning("Embedding 功能未启用,请在 .env 中设置 EMBEDDING_ENABLED=True")
return
client = EmbeddingClient()
# 测试模型加载
logger.info("正在加载本地 embedding 模型(首次运行需下载)...")
if not client.test_connection():
logger.error("Embedding 模型加载失败,请检查 sentence-transformers 是否安装")
return
logger.info("模型加载成功")
# 获取所有需要生成 embedding 的条目
with db_manager.get_session() as session:
entries = session.query(KnowledgeEntry).filter(
KnowledgeEntry.is_active == True
).all()
# 筛选出没有 embedding 的条目
to_process = []
for entry in entries:
if not entry.vector_embedding or entry.vector_embedding.strip() == '':
to_process.append(entry)
logger.info(f"{len(entries)} 条活跃知识,{len(to_process)} 条需要生成 embedding")
if not to_process:
logger.info("所有条目已有 embedding无需迁移")
return
# 批量生成
texts = [e.question + " " + e.answer for e in to_process]
logger.info(f"开始批量生成 embedding...")
vectors = client.embed_batch(texts)
success_count = 0
for i, entry in enumerate(to_process):
vec = vectors[i]
if vec:
entry.vector_embedding = json.dumps(vec)
success_count += 1
session.commit()
logger.info(f"Embedding 生成完成: 成功 {success_count}/{len(to_process)}")
# 重建向量索引
vector_store.load_from_db()
logger.info(f"向量索引重建完成: {vector_store.size}")
if __name__ == "__main__":
migrate()