大改,未验证
This commit is contained in:
81
scripts/migrate_embeddings.py
Normal file
81
scripts/migrate_embeddings.py
Normal file
@@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
批量为已有知识库条目生成 Embedding 向量(本地模型)
|
||||
运行方式: python scripts/migrate_embeddings.py
|
||||
|
||||
首次运行会自动下载模型(~95MB),之后走本地缓存
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from src.config.unified_config import get_config
|
||||
from src.core.database import db_manager
|
||||
from src.core.models import KnowledgeEntry
|
||||
from src.core.embedding_client import EmbeddingClient
|
||||
from src.core.vector_store import vector_store
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def migrate():
|
||||
config = get_config()
|
||||
if not config.embedding.enabled:
|
||||
logger.warning("Embedding 功能未启用,请在 .env 中设置 EMBEDDING_ENABLED=True")
|
||||
return
|
||||
|
||||
client = EmbeddingClient()
|
||||
|
||||
# 测试模型加载
|
||||
logger.info("正在加载本地 embedding 模型(首次运行需下载)...")
|
||||
if not client.test_connection():
|
||||
logger.error("Embedding 模型加载失败,请检查 sentence-transformers 是否安装")
|
||||
return
|
||||
logger.info("模型加载成功")
|
||||
|
||||
# 获取所有需要生成 embedding 的条目
|
||||
with db_manager.get_session() as session:
|
||||
entries = session.query(KnowledgeEntry).filter(
|
||||
KnowledgeEntry.is_active == True
|
||||
).all()
|
||||
|
||||
# 筛选出没有 embedding 的条目
|
||||
to_process = []
|
||||
for entry in entries:
|
||||
if not entry.vector_embedding or entry.vector_embedding.strip() == '':
|
||||
to_process.append(entry)
|
||||
|
||||
logger.info(f"共 {len(entries)} 条活跃知识,{len(to_process)} 条需要生成 embedding")
|
||||
|
||||
if not to_process:
|
||||
logger.info("所有条目已有 embedding,无需迁移")
|
||||
return
|
||||
|
||||
# 批量生成
|
||||
texts = [e.question + " " + e.answer for e in to_process]
|
||||
logger.info(f"开始批量生成 embedding...")
|
||||
vectors = client.embed_batch(texts)
|
||||
|
||||
success_count = 0
|
||||
for i, entry in enumerate(to_process):
|
||||
vec = vectors[i]
|
||||
if vec:
|
||||
entry.vector_embedding = json.dumps(vec)
|
||||
success_count += 1
|
||||
|
||||
session.commit()
|
||||
logger.info(f"Embedding 生成完成: 成功 {success_count}/{len(to_process)}")
|
||||
|
||||
# 重建向量索引
|
||||
vector_store.load_from_db()
|
||||
logger.info(f"向量索引重建完成: {vector_store.size} 条")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
migrate()
|
||||
Reference in New Issue
Block a user