Files
assist/scripts/migrate_embeddings.py
2026-03-20 16:50:26 +08:00

82 lines
2.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
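Generate embedding vectors in bulk for existing knowledge-base entries (local model).
Usage: python scripts/migrate_embeddings.py
The first run downloads the model automatically (~95 MB); later runs use the local cache.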
"""
import sys
import os
import json
import logging
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.config.unified_config import get_config
from src.core.database import db_manager
from src.core.models import KnowledgeEntry
from src.core.embedding_client import EmbeddingClient
from src.core.vector_store import vector_store
# Configure the root logger for this script: INFO level, with a timestamp and
# the level name in front of every message.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
# Module-level logger used by the migration routine below.
logger = logging.getLogger(__name__)
def migrate():
    """Backfill embeddings for active knowledge entries that lack one.

    Steps:
      1. Return early (with a warning) when ``config.embedding.enabled`` is
         falsy.
      2. Return early (with an error) when ``EmbeddingClient.test_connection()``
         reports failure.
      3. Load every ``KnowledgeEntry`` with ``is_active`` set, and keep those
         whose ``vector_embedding`` is empty or whitespace-only.
      4. Embed ``question + " " + answer`` for those entries in one
         ``embed_batch`` call and store each truthy vector as a JSON string
         on the entry, then commit.
      5. Reload the in-memory vector index via ``vector_store.load_from_db()``.

    Nothing is written when ``embed_batch`` returns a different number of
    vectors than texts: positional pairing would be unreliable, so the run
    is aborted before any entry is modified.

    Returns:
        None. Progress and failures are reported through ``logger``.
    """
    config = get_config()
    if not config.embedding.enabled:
        logger.warning("Embedding 功能未启用,请在 .env 中设置 EMBEDDING_ENABLED=True")
        return

    client = EmbeddingClient()

    # Verify the model can be used before touching the database.
    logger.info("正在加载本地 embedding 模型(首次运行需下载)...")
    if not client.test_connection():
        logger.error("Embedding 模型加载失败,请检查 sentence-transformers 是否安装")
        return
    logger.info("模型加载成功")

    with db_manager.get_session() as session:
        # SQLAlchemy builds the SQL predicate from the `==` operator, so this
        # must stay a comparison rather than a plain truthiness test.
        entries = session.query(KnowledgeEntry).filter(
            KnowledgeEntry.is_active == True  # noqa: E712
        ).all()

        # Keep only entries whose stored embedding is missing or blank.
        to_process = [
            entry for entry in entries
            if not (entry.vector_embedding or "").strip()
        ]
        logger.info(f"{len(entries)} 条活跃知识,{len(to_process)} 条需要生成 embedding")

        if not to_process:
            logger.info("所有条目已有 embedding无需迁移")
            return

        # A NULL question/answer is embedded as an empty string instead of
        # raising TypeError on string concatenation.
        texts = [f"{e.question or ''} {e.answer or ''}" for e in to_process]
        logger.info("开始批量生成 embedding...")
        vectors = client.embed_batch(texts)
        if vectors is None:
            vectors = []

        # Vectors are matched to entries by position. Abort before mutating
        # anything if the counts disagree, so a misaligned vector can never
        # be committed onto the wrong entry.
        if len(vectors) != len(to_process):
            logger.error(
                f"embed_batch 返回 {len(vectors)} 个向量,"
                f"与待处理条目数 {len(to_process)} 不一致,已中止且未写入"
            )
            return

        success_count = 0
        for entry, vec in zip(to_process, vectors):
            # A falsy vector (None / empty) marks a per-item failure; skip it.
            if vec:
                entry.vector_embedding = json.dumps(vec)
                success_count += 1

        session.commit()
        logger.info(f"Embedding 生成完成: 成功 {success_count}/{len(to_process)}")

    # Rebuild the in-memory vector index from the freshly committed rows.
    vector_store.load_from_db()
    logger.info(f"向量索引重建完成: {vector_store.size}")
# Run the migration only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    migrate()