82 lines
2.6 KiB
Python
82 lines
2.6 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
批量为已有知识库条目生成 Embedding 向量(本地模型)
|
||
运行方式: python scripts/migrate_embeddings.py
|
||
|
||
首次运行会自动下载模型(~95MB),之后走本地缓存
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
import json
|
||
import logging
|
||
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from src.config.unified_config import get_config
|
||
from src.core.database import db_manager
|
||
from src.core.models import KnowledgeEntry
|
||
from src.core.embedding_client import EmbeddingClient
|
||
from src.core.vector_store import vector_store
|
||
|
||
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def migrate():
|
||
config = get_config()
|
||
if not config.embedding.enabled:
|
||
logger.warning("Embedding 功能未启用,请在 .env 中设置 EMBEDDING_ENABLED=True")
|
||
return
|
||
|
||
client = EmbeddingClient()
|
||
|
||
# 测试模型加载
|
||
logger.info("正在加载本地 embedding 模型(首次运行需下载)...")
|
||
if not client.test_connection():
|
||
logger.error("Embedding 模型加载失败,请检查 sentence-transformers 是否安装")
|
||
return
|
||
logger.info("模型加载成功")
|
||
|
||
# 获取所有需要生成 embedding 的条目
|
||
with db_manager.get_session() as session:
|
||
entries = session.query(KnowledgeEntry).filter(
|
||
KnowledgeEntry.is_active == True
|
||
).all()
|
||
|
||
# 筛选出没有 embedding 的条目
|
||
to_process = []
|
||
for entry in entries:
|
||
if not entry.vector_embedding or entry.vector_embedding.strip() == '':
|
||
to_process.append(entry)
|
||
|
||
logger.info(f"共 {len(entries)} 条活跃知识,{len(to_process)} 条需要生成 embedding")
|
||
|
||
if not to_process:
|
||
logger.info("所有条目已有 embedding,无需迁移")
|
||
return
|
||
|
||
# 批量生成
|
||
texts = [e.question + " " + e.answer for e in to_process]
|
||
logger.info(f"开始批量生成 embedding...")
|
||
vectors = client.embed_batch(texts)
|
||
|
||
success_count = 0
|
||
for i, entry in enumerate(to_process):
|
||
vec = vectors[i]
|
||
if vec:
|
||
entry.vector_embedding = json.dumps(vec)
|
||
success_count += 1
|
||
|
||
session.commit()
|
||
logger.info(f"Embedding 生成完成: 成功 {success_count}/{len(to_process)}")
|
||
|
||
# 重建向量索引
|
||
vector_store.load_from_db()
|
||
logger.info(f"向量索引重建完成: {vector_store.size} 条")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
migrate()
|