from contextlib import contextmanager
import logging
from typing import Generator, Optional

from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine, make_url
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from .models import Base
from ..config.config import Config

logger = logging.getLogger(__name__)


class DatabaseManager:
    """Database manager.

    On construction, reads ``Config.get_database_config()`` (which must provide
    ``"url"`` and ``"echo"``), builds a SQLAlchemy engine and session factory,
    and creates every table registered on ``Base.metadata``.
    """

    def __init__(self):
        self.engine: Optional[Engine] = None
        self.SessionLocal: Optional[sessionmaker] = None
        self._initialize_database()

    @staticmethod
    def _build_engine(url: str, echo: bool) -> Engine:
        """Create an engine whose pool settings match the URL's backend.

        The backend is taken from the parsed URL scheme rather than a
        substring search, so e.g. ``sqlite:///mysql.db`` is treated as SQLite
        and non-MySQL server databases are not given SQLite-only options.
        """
        parsed = make_url(url)

        if parsed.get_backend_name() != "sqlite":
            # Server databases (MySQL, PostgreSQL, ...): a bounded connection
            # pool. pool_pre_ping discards connections the server has dropped;
            # pool_recycle replaces connections older than one hour.
            return create_engine(
                url,
                echo=echo,
                pool_size=10,
                max_overflow=20,
                pool_pre_ping=True,
                pool_recycle=3600,
            )

        # SQLite: let connections be used from threads other than the one
        # that opened them (sessions may be used across threads).
        sqlite_options = {"connect_args": {"check_same_thread": False}}
        if parsed.database in (None, "", ":memory:"):
            # An in-memory database exists only inside its connection, so all
            # sessions must share a single connection. File databases use the
            # default pool instead, so concurrent sessions do not share (and
            # interleave transactions on) one connection.
            sqlite_options["poolclass"] = StaticPool
        return create_engine(url, echo=echo, **sqlite_options)

    def _initialize_database(self):
        """Create the engine and session factory, then create all tables.

        Raises:
            Exception: any error while reading config, building the engine or
                creating tables is logged (with traceback) and re-raised.
        """
        try:
            db_config = Config.get_database_config()
            self.engine = self._build_engine(db_config["url"], db_config["echo"])
            self.SessionLocal = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self.engine,
            )
            # Create every table registered on Base.metadata.
            Base.metadata.create_all(bind=self.engine)
            logger.info("数据库初始化成功")
        except Exception as e:
            logger.exception("数据库初始化失败: %s", e)
            raise

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Context manager yielding a session scoped to one unit of work.

        Commits when the ``with`` body completes normally. If the body or the
        commit raises, the session is rolled back, the error is logged and
        re-raised. The session is always closed on exit.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error("数据库操作失败: %s", e)
            raise
        finally:
            session.close()

    def get_session_direct(self) -> Session:
        """Return a new, unmanaged session.

        The caller is responsible for committing/rolling back and for closing
        it (e.g. via ``close_session``).
        """
        return self.SessionLocal()

    def close_session(self, session: Optional[Session]):
        """Close ``session``; a ``None`` argument is ignored."""
        if session is not None:
            session.close()

    def test_connection(self) -> bool:
        """Return True if ``SELECT 1`` succeeds against the database.

        Uses a plain engine connection rather than ``get_session`` so a
        failure is logged once here instead of twice.
        """
        try:
            with self.engine.connect() as connection:
                connection.execute(text("SELECT 1"))
            return True
        except Exception as e:
            logger.error(f"数据库连接测试失败: {e}")
            return False


# Global database manager instance. Created at import time, so importing this
# module reads the config, connects, and creates tables.
db_manager = DatabaseManager()