from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.pool import StaticPool from contextlib import contextmanager from typing import Generator import logging from .models import Base from .cache_manager import cache_manager, cache_query from ..config.config import Config logger = logging.getLogger(__name__) class DatabaseManager: """数据库管理器""" def __init__(self): self.engine = None self.SessionLocal = None self._initialize_database() def _initialize_database(self): """初始化数据库连接""" try: db_config = Config.get_database_config() # 根据数据库类型选择不同的连接参数 if "mysql" in db_config["url"]: # MySQL配置 - 优化连接池 self.engine = create_engine( db_config["url"], echo=db_config["echo"], pool_size=20, # 增加连接池大小 max_overflow=30, # 增加溢出连接数 pool_pre_ping=True, pool_recycle=1800, # 减少回收时间 pool_timeout=10, # 连接超时 connect_args={ "charset": "utf8mb4", "autocommit": False } ) else: # SQLite配置 - 优化性能 self.engine = create_engine( db_config["url"], echo=db_config["echo"], poolclass=StaticPool, connect_args={ "check_same_thread": False, "timeout": 20, # 连接超时 "isolation_level": None # 自动提交模式 } ) self.SessionLocal = sessionmaker( autocommit=False, autoflush=False, bind=self.engine ) # 创建所有表 Base.metadata.create_all(bind=self.engine) logger.info("数据库初始化成功") except Exception as e: logger.error(f"数据库初始化失败: {e}") raise @contextmanager def get_session(self) -> Generator[Session, None, None]: """获取数据库会话的上下文管理器""" session = self.SessionLocal() try: yield session session.commit() except Exception as e: session.rollback() logger.error(f"数据库操作失败: {e}") raise finally: session.close() def get_session_direct(self) -> Session: """直接获取数据库会话""" return self.SessionLocal() def close_session(self, session: Session): """关闭数据库会话""" if session: session.close() def test_connection(self) -> bool: """测试数据库连接""" try: with self.get_session() as session: session.execute(text("SELECT 1")) return True except Exception as e: logger.error(f"数据库连接测试失败: {e}") return False @cache_query(ttl=60) # 缓存1分钟 def get_cached_query(self, query_key: str, query_func, *args, **kwargs): """执行带缓存的查询""" return query_func(*args, **kwargs) def invalidate_cache_pattern(self, pattern: str): """根据模式清除缓存""" try: cache_manager.delete(pattern) logger.info(f"缓存已清除: {pattern}") except Exception as e: logger.error(f"清除缓存失败: {e}") def get_cache_stats(self): """获取缓存统计信息""" return cache_manager.get_stats() # 全局数据库管理器实例 db_manager = DatabaseManager()