Files
assist/src/web/websocket_server.py

298 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
WebSocket实时通信服务器
提供实时对话功能
"""
import asyncio
import json
import logging
from datetime import datetime
from typing import Dict, Set
import websockets
from websockets.server import WebSocketServerProtocol
from ..dialogue.realtime_chat import RealtimeChatManager
# Module-level logger; its records are named after this module's dotted import path (__name__).
logger = logging.getLogger(__name__)
class WebSocketServer:
    """WebSocket server exposing ``RealtimeChatManager`` over a small JSON protocol.

    Every incoming frame is expected to be a JSON object whose ``"type"`` field selects
    a handler. An optional ``"messageId"`` field is echoed back in the reply so the
    client can correlate requests with responses.
    """

    def __init__(self, host: str = "localhost", port: int = 8765):
        """Create a server that will listen on *host*:*port* once started."""
        self.host = host
        self.port = port
        # Backend owning sessions and work orders; every handler delegates to it.
        self.chat_manager = RealtimeChatManager()
        # Open connections, maintained by register_client / unregister_client.
        self.connected_clients: Set[WebSocketServerProtocol] = set()

    async def register_client(self, websocket: WebSocketServerProtocol):
        """Start tracking a newly opened connection."""
        self.connected_clients.add(websocket)
        logger.info("客户端连接: %s", websocket.remote_address)

    async def unregister_client(self, websocket: WebSocketServerProtocol):
        """Stop tracking a connection; a no-op for connections that are not tracked."""
        self.connected_clients.discard(websocket)
        logger.info("客户端断开: %s", websocket.remote_address)

    async def handle_message(self, websocket: WebSocketServerProtocol, message: str):
        """Parse one client frame and dispatch it to the matching ``_handle_*`` method.

        An ``"error"`` reply is sent when the frame is not a JSON object, when its
        ``"type"`` is unknown, or when the handler raises. ``ConnectionClosed`` is not
        treated as an application error: it is re-raised so ``handle_client`` can end
        the connection instead of trying to send an error over a closed socket.
        """
        try:
            data = json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError, TypeError):
            # Malformed JSON, undecodable bytes, or a frame that is neither str nor bytes.
            await self._send_error(websocket, "JSON格式错误")
            return
        # json.loads also accepts arrays and scalars, which have no .get().
        if not isinstance(data, dict):
            await self._send_error(websocket, "JSON格式错误")
            return

        message_type = data.get("type")
        message_id = data.get("messageId")  # echoed back so the client can match replies

        handlers = {
            "create_session": self._handle_create_session,
            "send_message": self._handle_send_message,
            "get_history": self._handle_get_history,
            "create_work_order": self._handle_create_work_order,
            "get_work_order_status": self._handle_get_work_order_status,
            "end_session": self._handle_end_session,
        }
        # A JSON array/object in "type" is unhashable, so only look up strings.
        handler = handlers.get(message_type) if isinstance(message_type, str) else None
        if handler is None:
            await self._send_error(websocket, "未知消息类型", message_id)
            return

        try:
            await handler(websocket, data, message_id)
        except websockets.exceptions.ConnectionClosed:
            raise
        except Exception as e:
            logger.exception("处理消息失败: %s", e)
            # NOTE(review): the reply exposes the raw exception text to the client.
            await self._send_error(websocket, f"处理消息失败: {str(e)}", message_id)

    async def _send_response(self, websocket: WebSocketServerProtocol, payload: Dict, message_id: str = None):
        """Send *payload* as a JSON text frame.

        A ``"timestamp"`` (naive local time from ``datetime.now()``, ISO-8601) is
        appended after the payload fields, then ``"messageId"`` if *message_id* is
        truthy.
        """
        response = dict(payload)
        response["timestamp"] = datetime.now().isoformat()
        if message_id:
            response["messageId"] = message_id
        # ensure_ascii=False keeps non-ASCII (e.g. Chinese) text unescaped on the wire.
        await websocket.send(json.dumps(response, ensure_ascii=False))

    async def _handle_create_session(self, websocket: WebSocketServerProtocol, data: Dict, message_id: str = None):
        """Create a chat session and reply with ``session_created`` and its id.

        ``user_id`` defaults to ``"anonymous"``; ``work_order_id`` is optional.
        """
        user_id = data.get("user_id", "anonymous")
        work_order_id = data.get("work_order_id")
        session_id = self.chat_manager.create_session(user_id, work_order_id)
        await self._send_response(
            websocket,
            {"type": "session_created", "session_id": session_id},
            message_id,
        )

    async def _handle_send_message(self, websocket: WebSocketServerProtocol, data: Dict, message_id: str = None):
        """Pass ``message`` to the session's chat and reply with ``message_response``.

        Requires non-empty ``session_id`` and ``message``.
        """
        session_id = data.get("session_id")
        message = data.get("message")
        if not session_id or not message:
            await self._send_error(websocket, "缺少必要参数", message_id)
            return
        # NOTE(review): chat_manager is called synchronously on the event-loop thread, so a
        # slow call stalls every connection. Consider asyncio.to_thread once
        # RealtimeChatManager's thread-safety is confirmed.
        result = self.chat_manager.process_message(session_id, message)
        await self._send_response(
            websocket,
            {"type": "message_response", "session_id": session_id, "result": result},
            message_id,
        )

    async def _handle_get_history(self, websocket: WebSocketServerProtocol, data: Dict, message_id: str = None):
        """Reply with ``history_response`` holding the history of ``session_id``."""
        session_id = data.get("session_id")
        if not session_id:
            await self._send_error(websocket, "缺少会话ID", message_id)
            return
        history = self.chat_manager.get_session_history(session_id)
        await self._send_response(
            websocket,
            {"type": "history_response", "session_id": session_id, "history": history},
            message_id,
        )

    async def _handle_create_work_order(self, websocket: WebSocketServerProtocol, data: Dict, message_id: str = None):
        """Create a work order for a session and reply with ``work_order_created``.

        Requires ``session_id``, ``title`` and ``description``. ``category`` defaults to
        ``"技术问题"`` ("technical issue") and ``priority`` defaults to ``"medium"``.
        """
        session_id = data.get("session_id")
        title = data.get("title")
        description = data.get("description")
        category = data.get("category", "技术问题")
        priority = data.get("priority", "medium")
        if not session_id or not title or not description:
            await self._send_error(websocket, "缺少必要参数", message_id)
            return
        result = self.chat_manager.create_work_order(session_id, title, description, category, priority)
        await self._send_response(
            websocket,
            {"type": "work_order_created", "session_id": session_id, "result": result},
            message_id,
        )

    async def _handle_get_work_order_status(self, websocket: WebSocketServerProtocol, data: Dict, message_id: str = None):
        """Reply with ``work_order_status`` for the given ``work_order_id``."""
        work_order_id = data.get("work_order_id")
        if not work_order_id:
            await self._send_error(websocket, "缺少工单ID", message_id)
            return
        result = self.chat_manager.get_work_order_status(work_order_id)
        await self._send_response(
            websocket,
            {"type": "work_order_status", "work_order_id": work_order_id, "result": result},
            message_id,
        )

    async def _handle_end_session(self, websocket: WebSocketServerProtocol, data: Dict, message_id: str = None):
        """End ``session_id`` and reply with ``session_ended`` and the manager's success flag."""
        session_id = data.get("session_id")
        if not session_id:
            await self._send_error(websocket, "缺少会话ID", message_id)
            return
        success = self.chat_manager.end_session(session_id)
        await self._send_response(
            websocket,
            {"type": "session_ended", "session_id": session_id, "success": success},
            message_id,
        )

    async def _send_error(self, websocket: WebSocketServerProtocol, error_message: str, message_id: str = None):
        """Reply with an ``"error"`` message carrying *error_message*."""
        await self._send_response(websocket, {"type": "error", "message": error_message}, message_id)

    @staticmethod
    def _log_connection_header(websocket: WebSocketServerProtocol) -> None:
        """Warn when the handshake's Connection header does not mention "upgrade".

        This is diagnostic only; the connection is accepted either way. Connection
        objects without ``request_headers`` (per the original notes, websockets 15.x
        may not expose it and validates the handshake itself) are skipped silently.
        """
        try:
            if not hasattr(websocket, "request_headers"):
                return
            connection = websocket.request_headers.get("Connection", "").lower()
        except AttributeError:
            return
        if "upgrade" in connection:
            return
        if "keep-alive" in connection:
            # Non-standard, but the connection is still accepted.
            logger.warning("收到非标准连接头: %s", connection)
        else:
            logger.warning("连接头不包含upgrade: %s", connection)

    async def handle_client(self, websocket: WebSocketServerProtocol, path: str = ""):
        """Serve one connection until it closes.

        Registers the connection, dispatches every received frame to
        ``handle_message``, and always unregisters it afterwards. *path* is unused and
        kept for compatibility with callers passing the request path.
        """
        self._log_connection_header(websocket)
        await self.register_client(websocket)
        try:
            async for message in websocket:
                await self.handle_message(websocket, message)
        except websockets.exceptions.ConnectionClosed:
            # The connection was closed; cleanup happens in finally.
            pass
        except Exception as e:
            logger.exception("WebSocket连接错误: %s", e)
        finally:
            await self.unregister_client(websocket)

    async def start_server(self):
        """Listen on ws://host:port and serve connections until this task is cancelled."""
        logger.info("启动WebSocket服务器: ws://%s:%s", self.host, self.port)

        # Adapter with an optional path, so it works whether websockets calls the
        # handler as handler(websocket, path) or just handler(websocket).
        async def handle_client_with_cors(websocket: WebSocketServerProtocol, path: str = None):
            # Despite the name, no CORS handling happens here. To restrict which browser
            # origins may connect, pass origins=[...] to websockets.serve.
            await self.handle_client(websocket, path or "")

        async with websockets.serve(handle_client_with_cors, self.host, self.port):
            await asyncio.Future()  # never completes: keeps the server running

    def _process_request(self, path, request):
        """Answer plain (non-WebSocket) HTTP requests with an informational HTML page.

        Returns ``None`` for WebSocket upgrade requests so the handshake proceeds.
        Otherwise returns a ``(status, headers, body)`` tuple.

        NOTE(review): this hook is not passed to ``websockets.serve`` in
        ``start_server``, so it currently has no effect. The tuple return follows the
        legacy websockets ``process_request`` contract; confirm before wiring it in.
        """
        # The legacy hook receives the request Headers directly; newer APIs pass a
        # Request object exposing .headers. Accept either.
        headers = getattr(request, "headers", request)
        if headers.get("Upgrade", "").lower() == "websocket":
            return None  # let the WebSocket handshake continue
        ws_url = f"ws://{self.host}:{self.port}"
        body = f"""
<!DOCTYPE html>
<html>
<head>
    <title>WebSocket Server</title>
</head>
<body>
    <h1>WebSocket Server is running</h1>
    <p>This is a WebSocket server. Please use a WebSocket client to connect.</p>
    <p>WebSocket URL: {ws_url}</p>
</body>
</html>
"""
        return (
            200,
            [("Content-Type", "text/html; charset=utf-8")],
            body.encode("utf-8"),
        )

    def run(self):
        """Run the server in a new event loop, blocking the calling thread."""
        asyncio.run(self.start_server())
if __name__ == "__main__":
    # Script entry point. NOTE: this module uses a relative import
    # (from ..dialogue.realtime_chat ...), so start it with `python -m <package path>`;
    # running the file by path fails at import time before reaching this block.
    logging.basicConfig(level=logging.INFO)  # make INFO-level server logs visible
    WebSocketServer().run()  # default host/port; blocks until interrupted