95 lines
2.8 KiB
Python
95 lines
2.8 KiB
Python
|
|
from sqlalchemy import create_engine, text
|
||
|
|
from sqlalchemy.orm import sessionmaker, Session
|
||
|
|
from sqlalchemy.pool import StaticPool
|
||
|
|
from contextlib import contextmanager
|
||
|
|
from typing import Generator
|
||
|
|
import logging
|
||
|
|
|
||
|
|
from .models import Base
|
||
|
|
from ..config.config import Config
|
||
|
|
|
||
|
|
# Module-level logger named after this module; used by DatabaseManager below.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
class DatabaseManager:
    """Database manager.

    Owns the SQLAlchemy engine and the ``SessionLocal`` session factory.
    Construction reads ``Config.get_database_config()``, builds the engine,
    and calls ``Base.metadata.create_all``, so it touches the database
    immediately. Any failure during that setup is logged and re-raised.
    """

    # Connection-pool settings for server databases (MySQL etc.).
    _POOL_SIZE = 10
    _MAX_OVERFLOW = 20
    # Recycle pooled connections after an hour so they are replaced before a
    # server-side idle timeout (e.g. MySQL ``wait_timeout``) can drop them.
    _POOL_RECYCLE_SECONDS = 3600

    def __init__(self):
        self.engine = None
        self.SessionLocal = None
        self._initialize_database()

    def _initialize_database(self):
        """Initialize the engine and session factory, then create all tables.

        Raises:
            Exception: whatever the config lookup, engine creation or
                ``create_all`` raises. It is logged and re-raised unchanged.
        """
        try:
            db_config = Config.get_database_config()
            self.engine = self._build_engine(db_config["url"], db_config["echo"])

            self.SessionLocal = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self.engine,
            )

            # Create every table registered on the declarative Base.
            Base.metadata.create_all(bind=self.engine)
            logger.info("数据库初始化成功")
        except Exception as e:
            logger.error("数据库初始化失败: %s", e)
            raise

    def _build_engine(self, url, echo):
        """Create the engine with connection options suited to the URL's backend.

        Args:
            url: database URL (string or SQLAlchemy ``URL``).
            echo: forwarded to ``create_engine(echo=...)``.

        Returns:
            The new SQLAlchemy ``Engine``.
        """
        from sqlalchemy.engine import make_url

        # Parse the URL rather than substring-matching it. A SQLite path that
        # merely contains "mysql" must not get server pool options. Server
        # backends other than MySQL (e.g. PostgreSQL) must not receive the
        # SQLite-only ``check_same_thread`` connect arg, which their drivers
        # reject.
        parsed = make_url(url)

        if parsed.get_backend_name() != "sqlite":
            # Server database: a sized pool with liveness checks.
            return create_engine(
                url,
                echo=echo,
                pool_size=self._POOL_SIZE,
                max_overflow=self._MAX_OVERFLOW,
                pool_pre_ping=True,
                pool_recycle=self._POOL_RECYCLE_SECONDS,
            )

        # SQLite: allow connections to be used from threads other than the
        # one that created them.
        sqlite_kwargs = {"connect_args": {"check_same_thread": False}}
        if parsed.database in (None, "", ":memory:"):
            # StaticPool hands the same single connection to every session.
            # An in-memory database needs that, because each new connection
            # would see a fresh, empty database. For a file database it would
            # make all sessions share one transaction, so there the pool is
            # left to SQLAlchemy's default.
            sqlite_kwargs["poolclass"] = StaticPool
        return create_engine(url, echo=echo, **sqlite_kwargs)

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Context manager yielding a session inside a transaction.

        Commits when the ``with`` body completes normally. If the body or the
        commit raises, it rolls back, logs the error and re-raises. The
        session is always closed on exit.

        Yields:
            A new session from ``SessionLocal``.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error("数据库操作失败: %s", e)
            raise
        finally:
            session.close()

    def get_session_direct(self) -> Session:
        """Return a new session directly, without transaction management.

        The caller is responsible for committing or rolling back, and for
        closing the session (e.g. via ``close_session``).
        """
        return self.SessionLocal()

    def close_session(self, session: Session):
        """Close ``session``. Passing ``None`` is a no-op."""
        if session is not None:
            session.close()

    def test_connection(self) -> bool:
        """Check database connectivity by executing ``SELECT 1``.

        Returns:
            True if the query succeeds. False if it fails; the error is
            logged rather than raised.
        """
        try:
            with self.get_session() as session:
                session.execute(text("SELECT 1"))
            return True
        except Exception as e:
            logger.error("数据库连接测试失败: %s", e)
            return False
|
||
|
|
|
||
|
|
# Global database manager instance.
# NOTE: this is created at import time. DatabaseManager() reads Config, creates
# the engine and runs Base.metadata.create_all, so importing this module
# connects to the database, and any initialization error propagates out of
# the import.
db_manager = DatabaseManager()
|