Files
assist/scripts/update_manager.py

478 lines
18 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
TSP智能助手更新管理器
支持热更新版本管理回滚等功能
"""
import filecmp
import json
import os
import shutil
import subprocess
import sys
import tempfile
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import requests
class UpdateManager:
    """Manage deployments of the TSP assistant: updates, backups and rollback.

    Supported operations:
      * hot update  - copy a list of files into the deploy tree without
                      restarting the service;
      * full update - replace the deploy tree, install requirements, run
                      ``init_database.py`` and (optionally) restart the
                      systemd service, rolling back on failure;
      * creating, listing and restoring backups.

    NOTE(review): ``backup_path`` ("./backups") and ``config_file`` default to
    paths relative to the current working directory. If the script is run
    from inside the deploy directory, backups end up inside the tree that
    ``create_backup`` copies and ``full_update`` deletes; configure an
    absolute ``backup_path`` outside the deploy tree.
    """

    # Database file kept at the top level of the deploy directory. It is
    # backed up, restored on rollback, and preserved across full updates.
    DB_FILENAME = "tsp_assistant.db"

    # Default files (relative to the source/deploy roots) that hot_update()
    # copies in place; override with the "hot_update_files" config key.
    HOT_UPDATE_FILES = (
        "src/web/static/js/dashboard.js",
        "src/web/static/css/style.css",
        "src/web/templates/dashboard.html",
        "src/web/app.py",
        "src/knowledge_base/knowledge_manager.py",
        "src/dialogue/realtime_chat.py",
    )

    def __init__(self, config_file: str = "update_config.json"):
        """Load the configuration and, if importable, the project VersionManager.

        Args:
            config_file: Path of the JSON config file; its values are merged
                over the built-in defaults.
        """
        self.config_file = config_file
        self.config = self._load_config()
        self.version_manager = None
        # The version module is optional; without it update checks report "unknown".
        try:
            from version import VersionManager
            self.version_manager = VersionManager()
        except ImportError:
            print("警告: 版本管理器不可用")

    def _load_config(self) -> Dict:
        """Return the built-in defaults merged with ``config_file``, if it exists.

        Nested dictionaries (notably ``environments``) are merged recursively.
        A config file that overrides a single environment therefore keeps the
        defaults for the other environments instead of dropping them. A dropped
        environment would otherwise fall back to the production
        ``deploy_path``. Read/parse errors are printed and the defaults are
        used.
        """
        default_config = {
            "app_name": "tsp_assistant",
            "deploy_path": "/opt/tsp_assistant",
            "backup_path": "./backups",
            "service_name": "tsp_assistant",
            "health_url": "http://localhost:5000/api/health",
            "update_timeout": 300,
            "rollback_enabled": True,
            "auto_backup": True,
            "hot_update_enabled": True,
            "environments": {
                "development": {
                    "path": "./dev_deploy",
                    "service_name": "",
                    "auto_restart": False
                },
                "staging": {
                    "path": "/opt/tsp_assistant_staging",
                    "service_name": "tsp_assistant_staging",
                    "auto_restart": True
                },
                "production": {
                    "path": "/opt/tsp_assistant",
                    "service_name": "tsp_assistant",
                    "auto_restart": True
                }
            }
        }
        if os.path.exists(self.config_file):
            try:
                with open(self.config_file, 'r', encoding='utf-8') as f:
                    config = json.load(f)
                # Merge the file's values over the defaults (recursively).
                self._deep_merge(default_config, config)
            except Exception as e:
                print(f"加载配置文件失败: {e}")
        return default_config

    @staticmethod
    def _deep_merge(base: Dict, override: Dict) -> Dict:
        """Recursively merge *override* into *base* in place and return *base*.

        Keys whose values are dicts on both sides are merged; any other value
        in *override* replaces the one in *base*.
        """
        for key, value in override.items():
            current = base.get(key)
            if isinstance(current, dict) and isinstance(value, dict):
                UpdateManager._deep_merge(current, value)
            else:
                base[key] = value
        return base

    def _save_config(self):
        """Write ``self.config`` to ``config_file`` as JSON; errors are printed, not raised."""
        try:
            with open(self.config_file, 'w', encoding='utf-8') as f:
                json.dump(self.config, f, indent=2, ensure_ascii=False)
        except Exception as e:
            print(f"保存配置文件失败: {e}")

    def check_update_available(self, source_path: str) -> Tuple[bool, str, str]:
        """Compare the installed version with ``<source_path>/version.json``.

        Returns:
            ``(has_update, current_version, source_version)``. ``has_update`` is
            True whenever the two version strings differ. A lower source
            version (a downgrade) also counts. Returns
            ``(False, "unknown", "unknown")`` when no VersionManager is
            available. The source version is reported as "unknown" when
            version.json is missing or unreadable.
        """
        if not self.version_manager:
            return False, "unknown", "unknown"
        current_version = self.version_manager.get_version()
        # Read the version declared by the source tree.
        try:
            source_version_file = os.path.join(source_path, "version.json")
            if os.path.exists(source_version_file):
                with open(source_version_file, 'r', encoding='utf-8') as f:
                    source_info = json.load(f)
                source_version = source_info.get("version", "unknown")
            else:
                return False, current_version, "unknown"
        except Exception as e:
            print(f"检查源版本失败: {e}")
            return False, current_version, "unknown"
        # Plain string comparison: any difference is treated as an update.
        if source_version != current_version:
            return True, current_version, source_version
        return False, current_version, source_version

    def create_backup(self, environment: str = "production") -> str:
        """Snapshot the deployed application and database for *environment*.

        The backup is written to ``<backup_path>/<app_name>_backup_<timestamp>``.
        It contains ``app/`` (a copy of the deploy tree), ``database/`` (a copy
        of the database file, if present) and ``backup_info.json``.

        Args:
            environment: Key into ``config["environments"]``. Unknown keys fall
                back to the global ``deploy_path``.

        Returns:
            The backup name, as accepted by ``rollback``.

        Raises:
            OSError / shutil.Error: if copying fails. This includes a second
            backup created within the same second, whose ``app/`` directory
            already exists.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_name = f"{self.config['app_name']}_backup_{timestamp}"
        backup_path = os.path.join(self.config["backup_path"], backup_name)
        print(f"创建备份: {backup_name}")
        os.makedirs(backup_path, exist_ok=True)

        env_config = self.config["environments"].get(environment, {})
        deploy_path = env_config.get("path", self.config["deploy_path"])

        # Application files.
        if os.path.exists(deploy_path):
            print("备份应用文件...")
            shutil.copytree(deploy_path, os.path.join(backup_path, "app"))

        # Database, kept separately so rollback can restore it explicitly.
        db_file = os.path.join(deploy_path, self.DB_FILENAME)
        if os.path.exists(db_file):
            print("备份数据库...")
            os.makedirs(os.path.join(backup_path, "database"), exist_ok=True)
            shutil.copy2(db_file, os.path.join(backup_path, "database", self.DB_FILENAME))

        # NOTE(review): "version" comes from this process's VersionManager;
        # confirm it reflects the deployed tree rather than the script's own copy.
        backup_info = {
            "backup_name": backup_name,
            "backup_path": backup_path,
            "timestamp": timestamp,
            "environment": environment,
            "version": self.version_manager.get_version() if self.version_manager else "unknown",
            "git_commit": self._get_git_commit(deploy_path)
        }
        with open(os.path.join(backup_path, "backup_info.json"), 'w', encoding='utf-8') as f:
            json.dump(backup_info, f, indent=2, ensure_ascii=False)
        print(f"备份完成: {backup_name}")
        return backup_name

    def _get_git_commit(self, path: str) -> str:
        """Return the first 8 characters of ``git rev-parse HEAD`` in *path*.

        Returns "unknown" if git is not installed, *path* does not exist, or
        *path* is not inside a git repository.
        """
        try:
            result = subprocess.run(['git', 'rev-parse', 'HEAD'],
                                    cwd=path, capture_output=True, text=True)
        except (OSError, subprocess.SubprocessError):
            # git missing or cwd not a directory.
            return "unknown"
        return result.stdout.strip()[:8] if result.returncode == 0 else "unknown"

    def hot_update(self, source_path: str, environment: str = "production") -> bool:
        """Copy changed hot-updatable files into the deploy tree without a restart.

        Only the files listed in the ``hot_update_files`` config key (default:
        ``HOT_UPDATE_FILES``) are considered. A file is copied when it exists in
        the source and is missing from, or differs from, the deployed copy.

        Returns:
            True if at least one file was copied. False if hot updates are
            disabled or nothing needed copying.

        NOTE(review): the service is not restarted here. Whether changed .py
        files take effect depends on the running server reloading them;
        confirm.
        """
        if not self.config["hot_update_enabled"]:
            print("热更新未启用")
            return False
        print("开始热更新...")
        env_config = self.config["environments"].get(environment, {})
        deploy_path = env_config.get("path", self.config["deploy_path"])
        hot_update_files = self.config.get("hot_update_files", self.HOT_UPDATE_FILES)

        updated_files = []
        for file_path in hot_update_files:
            source_file = os.path.join(source_path, file_path)
            target_file = os.path.join(deploy_path, file_path)
            if not os.path.exists(source_file):
                continue
            # Only copy files whose content actually changed.
            if not os.path.exists(target_file) or not self._files_equal(source_file, target_file):
                print(f"更新文件: {file_path}")
                os.makedirs(os.path.dirname(target_file), exist_ok=True)
                shutil.copy2(source_file, target_file)
                updated_files.append(file_path)

        if updated_files:
            print(f"热更新完成,更新了 {len(updated_files)} 个文件")
            return True
        print("没有文件需要热更新")
        return False

    def _files_equal(self, file1: str, file2: str) -> bool:
        """Return True if both paths are regular files with identical contents.

        Returns False if either file cannot be stat'ed or read.
        """
        try:
            # shallow=False compares bytes, not just size/mtime.
            return filecmp.cmp(file1, file2, shallow=False)
        except OSError:
            return False

    @staticmethod
    def _paths_overlap(path_a: str, path_b: str) -> bool:
        """Return True if the two paths resolve to the same directory or one contains the other."""
        real_a = os.path.realpath(path_a)
        real_b = os.path.realpath(path_b)
        try:
            common = os.path.commonpath([real_a, real_b])
        except ValueError:
            # e.g. paths on different drives on Windows: they cannot overlap.
            return False
        return common in (real_a, real_b)

    def _sync_app_files(self, source_path: str, deploy_path: str) -> None:
        """Replace the contents of *deploy_path* with a copy of *source_path*.

        The deployed database file (``DB_FILENAME``), if present, is preserved.
        It is copied aside before the tree is deleted and put back afterwards,
        even if the copy fails part-way. This keeps live data from being
        replaced by whatever database the source tree does or does not contain.
        """
        db_file = os.path.join(deploy_path, self.DB_FILENAME)
        with tempfile.TemporaryDirectory() as tmp_dir:
            preserved_db = None
            if os.path.exists(db_file):
                preserved_db = os.path.join(tmp_dir, self.DB_FILENAME)
                shutil.copy2(db_file, preserved_db)
            try:
                if os.path.exists(deploy_path):
                    shutil.rmtree(deploy_path)
                os.makedirs(deploy_path, exist_ok=True)
                shutil.copytree(source_path, deploy_path, dirs_exist_ok=True)
            finally:
                # Put the live database back even when the file copy failed.
                if preserved_db is not None:
                    os.makedirs(deploy_path, exist_ok=True)
                    shutil.copy2(preserved_db, db_file)

    def full_update(self, source_path: str, environment: str = "production",
                    create_backup: bool = True) -> bool:
        """Replace the deployment with *source_path* and restart the service.

        Steps:
          1. Optional backup.
          2. Stop the service.
          3. Replace the deploy tree, keeping the live database file.
          4. ``chown -R www-data:www-data``.
          5. ``pip install -r requirements.txt``.
          6. Run ``init_database.py``.
          7. Start the service and health-check it.

        On failure, the backup (if one was made) is restored via ``rollback``.

        Args:
            source_path: Directory containing the new application tree.
            environment: Key into ``config["environments"]``.
            create_backup: Pass False to skip the backup. A backup is taken
                only if this is True and the ``auto_backup`` config value is
                truthy.

        Returns:
            True on success, including when the service is configured not to
            restart. False if the paths are invalid, a step raises, or the
            health check fails.

        NOTE(review): if a step fails after the service was stopped and no
        backup was taken, the service may be left stopped.
        """
        print("开始完整更新...")
        env_config = self.config["environments"].get(environment, {})
        deploy_path = env_config.get("path", self.config["deploy_path"])
        service_name = env_config.get("service_name", self.config["service_name"])
        auto_restart = env_config.get("auto_restart", True)

        # Validate before anything destructive: the deploy tree is deleted
        # below. A missing source, or one that overlaps the deploy tree,
        # would otherwise wipe the deployment with nothing usable to copy back.
        if not os.path.isdir(source_path):
            print(f"源路径不存在: {source_path}")
            return False
        if self._paths_overlap(source_path, deploy_path):
            print(f"源路径与部署路径重叠,拒绝更新: {source_path} / {deploy_path}")
            return False

        backup_name: Optional[str] = None
        if create_backup and self.config["auto_backup"]:
            backup_name = self.create_backup(environment)

        try:
            if auto_restart and service_name:
                print(f"停止服务: {service_name}")
                subprocess.run(['sudo', 'systemctl', 'stop', service_name], check=True)

            print("更新应用文件...")
            self._sync_app_files(source_path, deploy_path)

            subprocess.run(['sudo', 'chown', '-R', 'www-data:www-data', deploy_path], check=True)

            print("安装依赖...")
            requirements_file = os.path.join(deploy_path, "requirements.txt")
            if os.path.exists(requirements_file):
                subprocess.run(['sudo', '-u', 'www-data', 'python', '-m', 'pip', 'install', '-r', requirements_file],
                               cwd=deploy_path, check=True)

            print("运行数据库迁移...")
            init_script = os.path.join(deploy_path, "init_database.py")
            if os.path.exists(init_script):
                subprocess.run(['sudo', '-u', 'www-data', 'python', init_script],
                               cwd=deploy_path, check=True)

            if not (auto_restart and service_name):
                print("更新完成(未重启服务)")
                return True

            print(f"启动服务: {service_name}")
            subprocess.run(['sudo', 'systemctl', 'start', service_name], check=True)
            print("等待服务启动...")
            time.sleep(15)
            if self._health_check():
                print("更新成功!")
                return True
            print("健康检查失败,开始回滚...")
            if backup_name:
                self.rollback(backup_name, environment)
            return False
        except Exception as e:
            # Top-level boundary: report, attempt rollback, signal failure.
            print(f"更新失败: {e}")
            if backup_name:
                print("开始回滚...")
                self.rollback(backup_name, environment)
            return False

    def _health_check(self) -> bool:
        """Poll ``config["health_url"]`` until it answers HTTP 200.

        Makes up to 10 attempts with a 5 s request timeout and a 5 s pause
        between attempts.

        Returns:
            True as soon as a 200 response is seen, otherwise False.
        """
        health_url = self.config["health_url"]
        max_retries = 10
        for attempt in range(1, max_retries + 1):
            try:
                response = requests.get(health_url, timeout=5)
                if response.status_code == 200:
                    return True
            except requests.RequestException:
                # Connection refused / timeouts are expected while the service starts.
                pass
            print(f"健康检查失败,重试中... ({attempt}/{max_retries})")
            if attempt < max_retries:
                time.sleep(5)
        return False

    def rollback(self, backup_name: str, environment: str = "production") -> bool:
        """Restore the deployment for *environment* from backup *backup_name*.

        Steps:
          1. Stop the service.
          2. Replace the deploy tree with the backup's ``app/``.
          3. Copy back the database.
          4. chown to www-data.
          5. Restart and health-check the service.

        Returns:
            True on success, including when the service is configured not to
            restart. False if the backup does not exist, a step raises, or the
            health check fails.
        """
        print(f"开始回滚到备份: {backup_name}")
        env_config = self.config["environments"].get(environment, {})
        deploy_path = env_config.get("path", self.config["deploy_path"])
        service_name = env_config.get("service_name", self.config["service_name"])
        auto_restart = env_config.get("auto_restart", True)
        backup_path = os.path.join(self.config["backup_path"], backup_name)
        if not os.path.exists(backup_path):
            print(f"备份不存在: {backup_name}")
            return False

        try:
            if auto_restart and service_name:
                print(f"停止服务: {service_name}")
                subprocess.run(['sudo', 'systemctl', 'stop', service_name], check=True)

            print("恢复文件...")
            app_backup_path = os.path.join(backup_path, "app")
            if os.path.exists(app_backup_path):
                if os.path.exists(deploy_path):
                    shutil.rmtree(deploy_path)
                shutil.copytree(app_backup_path, deploy_path)

            db_backup_path = os.path.join(backup_path, "database", self.DB_FILENAME)
            if os.path.exists(db_backup_path):
                print("恢复数据库...")
                # The deploy dir may not exist if the backup had no app/ copy.
                os.makedirs(deploy_path, exist_ok=True)
                shutil.copy2(db_backup_path, os.path.join(deploy_path, self.DB_FILENAME))

            subprocess.run(['sudo', 'chown', '-R', 'www-data:www-data', deploy_path], check=True)

            if not (auto_restart and service_name):
                print("回滚完成(未重启服务)")
                return True

            print(f"启动服务: {service_name}")
            subprocess.run(['sudo', 'systemctl', 'start', service_name], check=True)
            time.sleep(15)
            if self._health_check():
                print("回滚成功!")
                return True
            print("回滚后健康检查失败")
            return False
        except Exception as e:
            print(f"回滚失败: {e}")
            return False

    def list_backups(self) -> List[Dict]:
        """Return the ``backup_info.json`` contents of every backup, newest first.

        Backup directories without a readable, valid info file are skipped.
        """
        backups = []
        backup_dir = self.config["backup_path"]
        if not os.path.exists(backup_dir):
            return backups
        for item in os.listdir(backup_dir):
            info_file = os.path.join(backup_dir, item, "backup_info.json")
            if not os.path.isfile(info_file):
                continue
            try:
                with open(info_file, 'r', encoding='utf-8') as f:
                    backups.append(json.load(f))
            except (OSError, ValueError):
                # Unreadable or corrupt info file: skip this backup.
                continue
        # Timestamps are "%Y%m%d_%H%M%S", so string order equals time order.
        return sorted(backups, key=lambda x: x.get("timestamp", ""), reverse=True)

    def auto_update(self, source_path: str, environment: str = "production") -> bool:
        """Update if the source version differs, preferring a hot update.

        Tries ``hot_update`` first. If it copies nothing (or is disabled),
        falls back to ``full_update``.

        Returns:
            True if no update was needed or the update succeeded.

        NOTE(review): a successful hot update returns immediately, so changes
        outside the hot-update file list are not deployed on that path.
        """
        print("开始自动更新...")
        has_update, current_version, new_version = self.check_update_available(source_path)
        if not has_update:
            print("没有更新可用")
            return True
        print(f"发现更新: {current_version} -> {new_version}")
        if self.hot_update(source_path, environment):
            print("热更新成功")
            return True
        print("热更新失败,进行完整更新...")
        return self.full_update(source_path, environment)
def main():
    """Command-line entry point: parse arguments and run one update action.

    Exits with status 1 when a required argument (``--source`` or
    ``--backup``) is missing, or when an update/rollback action reports
    failure; ``check`` and ``list-backups`` only print results.
    """
    import argparse

    action_names = ['check', 'hot-update', 'full-update', 'auto-update', 'rollback', 'list-backups']
    parser = argparse.ArgumentParser(description='TSP智能助手更新管理器')
    parser.add_argument('action', choices=action_names, help='要执行的操作')
    parser.add_argument('--source', help='源路径')
    parser.add_argument('--environment', choices=['development', 'staging', 'production'],
                        default='production', help='目标环境')
    parser.add_argument('--backup', help='备份名称(用于回滚)')
    parser.add_argument('--no-backup', action='store_true', help='跳过备份')
    args = parser.parse_args()

    manager = UpdateManager()
    action = args.action

    # Every action that inspects or deploys a source tree needs --source.
    needs_source = action in ('check', 'hot-update', 'full-update', 'auto-update')
    if needs_source and not args.source:
        print("错误: 需要指定源路径")
        sys.exit(1)

    if action == 'check':
        available, current, new = manager.check_update_available(args.source)
        message = f"有更新可用: {current} -> {new}" if available else f"没有更新可用 (当前版本: {current})"
        print(message)
        return

    if action == 'list-backups':
        entries = manager.list_backups()
        if not entries:
            print("没有找到备份")
            return
        print("可用备份:")
        for entry in entries:
            print(f" {entry['backup_name']} - {entry['timestamp']} - {entry.get('version', 'unknown')}")
        return

    # Remaining actions report success/failure through the exit status.
    if action == 'rollback':
        if not args.backup:
            print("错误: 需要指定备份名称")
            sys.exit(1)
        succeeded = manager.rollback(args.backup, args.environment)
    elif action == 'hot-update':
        succeeded = manager.hot_update(args.source, args.environment)
    elif action == 'full-update':
        succeeded = manager.full_update(args.source, args.environment, not args.no_backup)
    else:  # 'auto-update'
        succeeded = manager.auto_update(args.source, args.environment)
    sys.exit(0 if succeeded else 1)
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()