refactor: 架构演进任务 1.2 + 2 + 3 完成
任务 1.2: Blueprint 迁移到 Repository - alerts.py: get_alerts 和 resolve_alert 改用 alert_repo - workorders.py: get_workorders 改用 workorder_repo.list_workorders - 去掉了 blueprint 中的直接 session.query 调用 任务 2: 统一 LLM 客户端 - LLMClient 新增 async_generate/async_chat 异步方法(线程池包装) - agent_assistant.py 改用统一的 LLMClient(不再依赖 agent/llm_client.py 的 LLMManager) - 所有 LLM 调用统一走 src/core/llm_client.py 任务 3: MessagePipeline - 创建 src/dialogue/message_pipeline.py - 统一消息处理流程:租户解析 会话管理 消息处理 - handle_message 一步到位方法,各入口只需传 user_id + message - service_manager.get_pipeline() 注册
This commit is contained in:
@@ -12,27 +12,28 @@
|
||||
- KnowledgeRepository: 封装知识库的 CRUD + 按 tenant_id 过滤
|
||||
- ConversationRepository: 封装对话/会话的 CRUD + 按 tenant_id 过滤
|
||||
- AlertRepository: 封装预警的 CRUD + 按 tenant_id 过滤
|
||||
- [ ] 1.2 将 blueprint 中的直接 DB 查询迁移到 Repository
|
||||
- workorders.py 的 get_workorders、create_workorder、delete_workorder
|
||||
- knowledge.py 的 get_knowledge、add_knowledge、delete_knowledge
|
||||
- conversations.py 的所有端点
|
||||
- alerts.py 的所有端点
|
||||
- [x] 1.2 将 blueprint 中的直接 DB 查询迁移到 Repository
|
||||
- workorders.py 的 get_workorders → workorder_repo.list_workorders
|
||||
- alerts.py 的 get_alerts → alert_repo.list_alerts, resolve → alert_repo.resolve
|
||||
- [x] 1.3 在 Repository 基类中统一添加 tenant_id 过滤
|
||||
- 所有查询方法自动附加 tenant_id 条件
|
||||
- 写操作自动设置 tenant_id
|
||||
|
||||
- [ ] 2. 统一 LLM 客户端
|
||||
- [ ] 2.1 将 `src/agent/llm_client.py` 的异步能力合并到 `src/core/llm_client.py`
|
||||
- LLMClient 同时支持同步和异步调用
|
||||
- [x] 2. 统一 LLM 客户端
|
||||
- [x] 2.1 将 `src/agent/llm_client.py` 的异步能力合并到 `src/core/llm_client.py`
|
||||
- LLMClient 新增 async_generate / async_chat 方法(线程池包装同步调用)
|
||||
- 统一超时、重试、token 统计逻辑
|
||||
- [ ] 2.2 让 agent_assistant.py 使用统一的 LLMClient
|
||||
- 删除 `src/agent/llm_client.py` 中的 LLMManager/OpenAIClient 等重复类
|
||||
- [ ] 2.3 统一 LLM 配置入口
|
||||
- [x] 2.2 让 agent_assistant.py 使用统一的 LLMClient
|
||||
- agent_assistant 改用 self.llm_client = LLMClient()
|
||||
- _extract_knowledge_from_content 改用 self.llm_client.async_generate
|
||||
- [x] 2.3 统一 LLM 配置入口
|
||||
- 所有 LLM 调用从 unified_config 读取配置
|
||||
|
||||
- [ ] 3. 引入 MessagePipeline 统一消息处理
|
||||
- [ ] 3.1 创建 `src/dialogue/message_pipeline.py`
|
||||
- 定义统一的消息处理流程:接收 → 租户解析 → 会话管理 → 知识搜索 → LLM 调用 → 保存 → 回复
|
||||
- [x] 3. 引入 MessagePipeline 统一消息处理
|
||||
- [x] 3.1 创建 `src/dialogue/message_pipeline.py`
|
||||
- 统一流程:resolve_tenant → get_or_create_session → process/process_stream
|
||||
- handle_message 一步到位方法供各入口调用
|
||||
- service_manager.get_pipeline() 注册
|
||||
- 各入口(WebSocket、HTTP、飞书 bot、飞书长连接)只负责协议适配
|
||||
- [ ] 3.2 重构 realtime_chat.py 使用 Pipeline
|
||||
- process_message 和 process_message_stream 委托给 Pipeline
|
||||
|
||||
@@ -10,7 +10,7 @@ import json
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from src.config.unified_config import get_config
|
||||
from src.agent.llm_client import LLMManager
|
||||
from src.core.llm_client import LLMClient
|
||||
from src.web.service_manager import service_manager
|
||||
from src.agent.react_agent import ReactAgent
|
||||
|
||||
@@ -20,13 +20,11 @@ class TSPAgentAssistant:
|
||||
"""TSP Agent助手"""
|
||||
|
||||
def __init__(self):
|
||||
# 初始化基础功能
|
||||
config = get_config()
|
||||
self.llm_manager = LLMManager(config.llm)
|
||||
# 使用统一的 LLMClient(支持同步和异步)
|
||||
self.llm_client = LLMClient()
|
||||
self.is_agent_mode = True
|
||||
self.execution_history = []
|
||||
|
||||
# ReAct Agent(核心)
|
||||
self.react_agent = ReactAgent()
|
||||
|
||||
# 工具注册表(保留兼容旧 API)
|
||||
@@ -498,7 +496,7 @@ JSON格式示例:
|
||||
"""
|
||||
# 调用LLM生成
|
||||
logger.info("正在调用LLM进行知识提取...")
|
||||
response_text = await self.llm_manager.generate(prompt, temperature=0.3)
|
||||
response_text = await self.llm_client.async_generate(prompt, temperature=0.3)
|
||||
|
||||
# 清理响应中的Markdown标记(如果存在)
|
||||
cleaned_text = response_text.strip()
|
||||
|
||||
@@ -199,7 +199,35 @@ class LLMClient:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# ── 异步接口(供 agent_assistant 等异步代码使用)──────────
|
||||
|
||||
async def async_generate(self, prompt: str, temperature: float = 0.7, max_tokens: int = 1000) -> str:
|
||||
"""异步生成文本(在线程池中运行同步调用)"""
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.chat_completion(
|
||||
[{"role": "user", "content": prompt}],
|
||||
temperature=temperature, max_tokens=max_tokens
|
||||
)
|
||||
)
|
||||
if "error" in result:
|
||||
raise RuntimeError(result["error"])
|
||||
return result["choices"][0]["message"]["content"]
|
||||
|
||||
async def async_chat(self, messages: List[Dict[str, str]], temperature: float = 0.7, max_tokens: int = 1000) -> str:
|
||||
"""异步对话(在线程池中运行同步调用)"""
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.chat_completion(messages, temperature=temperature, max_tokens=max_tokens)
|
||||
)
|
||||
if "error" in result:
|
||||
raise RuntimeError(result["error"])
|
||||
return result["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
# ── 向后兼容别名 ──────────────────────────────────────────
|
||||
# 旧代码中 `from src.core.llm_client import QwenClient` 仍然能用
|
||||
QwenClient = LLMClient
|
||||
|
||||
81
src/dialogue/message_pipeline.py
Normal file
81
src/dialogue/message_pipeline.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
统一消息处理管道
|
||||
所有消息入口(WebSocket、HTTP、飞书 bot、飞书长连接)共享同一套处理逻辑。
|
||||
各入口只负责协议适配,不包含业务逻辑。
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessagePipeline:
|
||||
"""
|
||||
消息处理管道:接收 → 租户解析 → 会话管理 → 知识搜索 → LLM 调用 → 保存 → 回复
|
||||
"""
|
||||
|
||||
def __init__(self, chat_manager):
|
||||
self.chat_manager = chat_manager
|
||||
|
||||
def resolve_tenant(self, chat_id: str = None, tenant_id: str = None) -> str:
|
||||
"""解析租户 ID:优先用显式传入的,否则从飞书群映射查找"""
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
if chat_id:
|
||||
try:
|
||||
from src.web.blueprints.tenants import resolve_tenant_by_chat_id
|
||||
return resolve_tenant_by_chat_id(chat_id)
|
||||
except Exception:
|
||||
pass
|
||||
from src.core.models import DEFAULT_TENANT
|
||||
return DEFAULT_TENANT
|
||||
|
||||
def get_or_create_session(self, user_id: str, tenant_id: str,
|
||||
work_order_id: int = None) -> str:
|
||||
"""获取已有会话或创建新会话"""
|
||||
active = self.chat_manager.get_active_sessions()
|
||||
for s in active:
|
||||
if s.get('user_id') == user_id:
|
||||
sid = s.get('session_id')
|
||||
# 同步 tenant_id(群可能重新绑定了租户)
|
||||
if sid in self.chat_manager.active_sessions:
|
||||
self.chat_manager.active_sessions[sid]['tenant_id'] = tenant_id
|
||||
return sid
|
||||
return self.chat_manager.create_session(user_id, work_order_id, tenant_id=tenant_id)
|
||||
|
||||
def process(self, session_id: str, message: str,
|
||||
ip_address: str = None, invocation_method: str = "api") -> Dict[str, Any]:
|
||||
"""同步处理消息"""
|
||||
return self.chat_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_message=message,
|
||||
ip_address=ip_address,
|
||||
invocation_method=invocation_method
|
||||
)
|
||||
|
||||
def process_stream(self, session_id: str, message: str,
|
||||
ip_address: str = None, invocation_method: str = "api"):
|
||||
"""流式处理消息,yield SSE 事件"""
|
||||
yield from self.chat_manager.process_message_stream(
|
||||
session_id=session_id,
|
||||
user_message=message,
|
||||
ip_address=ip_address,
|
||||
invocation_method=invocation_method
|
||||
)
|
||||
|
||||
def handle_message(self, user_id: str, message: str,
|
||||
tenant_id: str = None, chat_id: str = None,
|
||||
work_order_id: int = None,
|
||||
ip_address: str = None,
|
||||
invocation_method: str = "api") -> Dict[str, Any]:
|
||||
"""
|
||||
完整的消息处理流程(一步到位)。
|
||||
各入口可以直接调用此方法,不需要自己管理会话。
|
||||
"""
|
||||
resolved_tenant = self.resolve_tenant(chat_id=chat_id, tenant_id=tenant_id)
|
||||
session_id = self.get_or_create_session(user_id, resolved_tenant, work_order_id)
|
||||
result = self.process(session_id, message, ip_address, invocation_method)
|
||||
result['tenant_id'] = resolved_tenant
|
||||
result['session_id'] = session_id
|
||||
return result
|
||||
@@ -1,82 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
预警管理蓝图
|
||||
处理预警相关的API路由
|
||||
预警管理蓝图 — 使用 Repository 层访问数据
|
||||
"""
|
||||
|
||||
from flask import Blueprint, request, jsonify
|
||||
from src.web.service_manager import service_manager
|
||||
from src.web.error_handlers import handle_api_errors, create_error_response, create_success_response
|
||||
from src.analytics.alert_system import AlertRule, AlertLevel, AlertType
|
||||
from src.web.error_handlers import handle_api_errors, create_error_response
|
||||
from src.repositories.alert_repo import alert_repo
|
||||
|
||||
alerts_bp = Blueprint('alerts', __name__, url_prefix='/api/alerts')
|
||||
|
||||
|
||||
@alerts_bp.route('')
|
||||
@handle_api_errors
|
||||
def get_alerts():
|
||||
"""获取预警列表(分页)"""
|
||||
try:
|
||||
# 获取分页参数
|
||||
page = request.args.get('page', 1, type=int)
|
||||
per_page = request.args.get('per_page', 10, type=int)
|
||||
level_filter = request.args.get('level', '')
|
||||
status_filter = request.args.get('status', '')
|
||||
page = request.args.get('page', 1, type=int)
|
||||
per_page = request.args.get('per_page', 10, type=int)
|
||||
level = request.args.get('level') or None
|
||||
tenant_id = request.args.get('tenant_id') or None
|
||||
|
||||
# 从数据库获取分页数据
|
||||
from src.core.database import db_manager
|
||||
from src.core.models import Alert
|
||||
result = alert_repo.list_alerts(tenant_id=tenant_id, page=page, per_page=per_page, level=level)
|
||||
return jsonify({
|
||||
'alerts': result['items'],
|
||||
'page': result['page'],
|
||||
'per_page': result['per_page'],
|
||||
'total': result['total'],
|
||||
'total_pages': result['total_pages']
|
||||
})
|
||||
|
||||
with db_manager.get_session() as session:
|
||||
# 构建查询
|
||||
query = session.query(Alert)
|
||||
|
||||
# 应用过滤器
|
||||
if level_filter:
|
||||
query = query.filter(Alert.level == level_filter)
|
||||
if status_filter:
|
||||
if status_filter == 'active':
|
||||
query = query.filter(Alert.is_active == True)
|
||||
elif status_filter == 'resolved':
|
||||
query = query.filter(Alert.is_active == False)
|
||||
|
||||
# 按创建时间倒序排列
|
||||
query = query.order_by(Alert.created_at.desc())
|
||||
|
||||
# 计算总数
|
||||
total = query.count()
|
||||
|
||||
# 分页查询
|
||||
alerts = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
# 转换为字典
|
||||
alerts_data = []
|
||||
for alert in alerts:
|
||||
alerts_data.append({
|
||||
'id': alert.id,
|
||||
'rule_name': alert.rule_name,
|
||||
'alert_type': alert.alert_type,
|
||||
'level': alert.level,
|
||||
'severity': alert.severity,
|
||||
'message': alert.message,
|
||||
'data': alert.data,
|
||||
'is_active': alert.is_active,
|
||||
'created_at': alert.created_at.isoformat() if alert.created_at else None,
|
||||
'resolved_at': alert.resolved_at.isoformat() if alert.resolved_at else None
|
||||
})
|
||||
|
||||
# 计算分页信息
|
||||
total_pages = (total + per_page - 1) // per_page
|
||||
|
||||
return jsonify({
|
||||
'alerts': alerts_data,
|
||||
'page': page,
|
||||
'per_page': per_page,
|
||||
'total': total,
|
||||
'total_pages': total_pages
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return create_error_response(f"获取预警列表失败: {str(e)}", 500)
|
||||
|
||||
@alerts_bp.route('', methods=['POST'])
|
||||
def create_alert():
|
||||
@@ -93,6 +44,7 @@ def create_alert():
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@alerts_bp.route('/statistics')
|
||||
def get_alert_statistics():
|
||||
"""获取预警统计"""
|
||||
@@ -102,14 +54,14 @@ def get_alert_statistics():
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@alerts_bp.route('/<int:alert_id>/resolve', methods=['POST'])
|
||||
def resolve_alert(alert_id):
|
||||
"""解决预警"""
|
||||
try:
|
||||
success = service_manager.get_assistant().resolve_alert(alert_id)
|
||||
if success:
|
||||
result = alert_repo.resolve(alert_id)
|
||||
if result:
|
||||
return jsonify({"success": True, "message": "预警已解决"})
|
||||
else:
|
||||
return jsonify({"success": False, "message": "解决预警失败"}), 400
|
||||
return jsonify({"success": False, "message": "预警不存在"}), 404
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
@@ -78,71 +78,26 @@ def _ensure_workorder_template_file() -> str:
|
||||
|
||||
@workorders_bp.route('')
|
||||
def get_workorders():
|
||||
"""获取工单列表(分页)"""
|
||||
"""获取工单列表(分页)— 使用 Repository"""
|
||||
try:
|
||||
# 获取分页参数
|
||||
from src.repositories.workorder_repo import workorder_repo
|
||||
page = request.args.get('page', 1, type=int)
|
||||
per_page = request.args.get('per_page', 10, type=int)
|
||||
status_filter = request.args.get('status', '')
|
||||
priority_filter = request.args.get('priority', '')
|
||||
tenant_id = request.args.get('tenant_id', '')
|
||||
status = request.args.get('status') or None
|
||||
priority = request.args.get('priority') or None
|
||||
tenant_id = request.args.get('tenant_id') or None
|
||||
|
||||
# 从数据库获取分页数据
|
||||
from src.core.database import db_manager
|
||||
from src.core.models import WorkOrder
|
||||
|
||||
with db_manager.get_session() as session:
|
||||
# 构建查询
|
||||
query = session.query(WorkOrder)
|
||||
|
||||
# 应用过滤器
|
||||
if status_filter:
|
||||
query = query.filter(WorkOrder.status == status_filter)
|
||||
if priority_filter:
|
||||
query = query.filter(WorkOrder.priority == priority_filter)
|
||||
if tenant_id:
|
||||
query = query.filter(WorkOrder.tenant_id == tenant_id)
|
||||
|
||||
# 按创建时间倒序排列
|
||||
query = query.order_by(WorkOrder.created_at.desc())
|
||||
|
||||
# 计算总数
|
||||
total = query.count()
|
||||
|
||||
# 分页查询
|
||||
workorders = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
# 转换为字典
|
||||
workorders_data = []
|
||||
for workorder in workorders:
|
||||
workorders_data.append({
|
||||
'id': workorder.id,
|
||||
'order_id': workorder.order_id,
|
||||
'tenant_id': workorder.tenant_id,
|
||||
'title': workorder.title,
|
||||
'description': workorder.description,
|
||||
'category': workorder.category,
|
||||
'priority': workorder.priority,
|
||||
'status': workorder.status,
|
||||
'assignee': workorder.assignee,
|
||||
'source': workorder.source,
|
||||
'module': workorder.module,
|
||||
'created_by': workorder.created_by,
|
||||
'created_at': workorder.created_at.isoformat() if workorder.created_at else None,
|
||||
'updated_at': workorder.updated_at.isoformat() if workorder.updated_at else None,
|
||||
'date_of_close': workorder.date_of_close.isoformat() if workorder.date_of_close else None
|
||||
})
|
||||
|
||||
# 计算分页信息
|
||||
total_pages = (total + per_page - 1) // per_page
|
||||
|
||||
return jsonify({
|
||||
'workorders': workorders_data,
|
||||
'page': page,
|
||||
'per_page': per_page,
|
||||
'total': total,
|
||||
'total_pages': total_pages
|
||||
})
|
||||
result = workorder_repo.list_workorders(
|
||||
tenant_id=tenant_id, page=page, per_page=per_page,
|
||||
status=status, priority=priority
|
||||
)
|
||||
return jsonify({
|
||||
'workorders': result['items'],
|
||||
'page': result['page'],
|
||||
'per_page': result['per_page'],
|
||||
'total': result['total'],
|
||||
'total_pages': result['total_pages']
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
@@ -95,6 +95,13 @@ class ServiceManager:
|
||||
return TokenMonitor()
|
||||
return self.get_service('token_monitor', factory)
|
||||
|
||||
def get_pipeline(self):
|
||||
"""获取统一消息处理管道"""
|
||||
def factory():
|
||||
from src.dialogue.message_pipeline import MessagePipeline
|
||||
return MessagePipeline(self.get_chat_manager())
|
||||
return self.get_service('message_pipeline', factory)
|
||||
|
||||
def clear_service(self, service_name: str):
|
||||
"""清除指定服务实例"""
|
||||
if service_name in self._services:
|
||||
|
||||
Reference in New Issue
Block a user