122 lines
5.0 KiB
Python
122 lines
5.0 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
Repository 基类

所有数据访问均经由 Repository 层;当调用方传入 tenant_id 且模型含租户字段时,查询会自动附加 tenant_id 过滤。
"""
|
|||
|
|
import logging
|
|||
|
|
from typing import Any, Dict, List, Optional, Type
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
from src.core.database import db_manager
|
|||
|
|
from src.core.models import DEFAULT_TENANT
|
|||
|
|
|
|||
|
|
# Module-level logger, named after this module per the usual logging convention.
logger = logging.getLogger(name=__name__)


class BaseRepository:
    """Base class for tenant-aware data access.

    Subclasses set ``model_class`` (the SQLAlchemy model to operate on) and,
    when the tenant column is not called ``tenant_id``, ``tenant_field``.

    Tenant scoping: the read/update/delete methods add a
    ``model_class.<tenant_field> == tenant_id`` filter only when a truthy
    ``tenant_id`` is passed AND ``model_class`` has ``tenant_field``. Passing
    ``tenant_id=None`` applies no tenant filter, so such calls see every
    tenant's rows. ``create`` behaves differently: without ``tenant_id`` it
    keeps the tenant value from ``data`` or falls back to ``DEFAULT_TENANT``.

    Every public method opens its own session via ``db_manager.get_session()``.
    NOTE(review): writes are presumably committed when that context exits
    successfully -- confirm in ``src.core.database``.
    """

    model_class: Optional[Type] = None  # must be set by every subclass
    tenant_field = 'tenant_id'  # attribute name of the tenant column on model_class

    def _base_query(self, session: Session, tenant_id: Optional[str] = None):
        """Return ``session.query(model_class)``, tenant-filtered when applicable.

        The tenant filter is added only when ``tenant_id`` is truthy and the
        model has ``tenant_field``; otherwise the query is unscoped.
        """
        query = session.query(self.model_class)
        if tenant_id and hasattr(self.model_class, self.tenant_field):
            query = query.filter(getattr(self.model_class, self.tenant_field) == tenant_id)
        return query

    def _apply_filters(self, query, filters: Optional[Dict] = None):
        """Return ``query`` with one equality filter per ``filters`` entry.

        Entries whose value is ``None``, or whose key is not an attribute of
        ``model_class``, are skipped silently.
        """
        for field, value in (filters or {}).items():
            if value is not None and hasattr(self.model_class, field):
                query = query.filter(getattr(self.model_class, field) == value)
        return query

    def get_by_id(self, id: int, tenant_id: Optional[str] = None) -> Optional[Dict]:
        """Return the row with primary key ``id`` as a dict, or ``None``.

        When ``tenant_id`` is given, a row owned by another tenant is
        reported as absent.
        """
        with db_manager.get_session() as session:
            obj = self._base_query(session, tenant_id).filter(self.model_class.id == id).first()
            return self._to_dict(obj) if obj is not None else None

    def list(self, tenant_id: Optional[str] = None, page: int = 1, per_page: int = 20,
             filters: Optional[Dict] = None, order_by=None) -> Dict[str, Any]:
        """Return one page of rows plus pagination metadata.

        Args:
            tenant_id: tenant to scope to; ``None`` means no tenant filter.
            page: 1-based page number; values below 1 are treated as 1.
            per_page: page size; values below 1 are treated as 1.
            filters: ``{attribute_name: value}`` equality filters; ``None``
                values and unknown attribute names are ignored.
            order_by: optional SQLAlchemy ordering clause.

        Returns:
            A dict with ``items`` (list of row dicts), ``page``, ``per_page``
            (the values actually used), ``total`` and ``total_pages``.
        """
        # Clamp so page<=0 cannot produce a negative OFFSET and per_page<=0
        # cannot cause a ZeroDivisionError when computing total_pages.
        page = max(page, 1)
        per_page = max(per_page, 1)
        with db_manager.get_session() as session:
            query = self._apply_filters(self._base_query(session, tenant_id), filters)
            # Count before ordering: ORDER BY does not affect the total, and some
            # databases (e.g. SQL Server) reject ORDER BY inside the COUNT subquery.
            total = query.count()
            if order_by is not None:
                query = query.order_by(order_by)
            rows = query.offset((page - 1) * per_page).limit(per_page).all()
            return {
                'items': [self._to_dict(row) for row in rows],
                'page': page,
                'per_page': per_page,
                'total': total,
                'total_pages': (total + per_page - 1) // per_page,
            }

    def create(self, data: Dict, tenant_id: Optional[str] = None) -> Dict:
        """Insert a row built from ``data`` and return it as a dict.

        If the model has ``tenant_field``, the tenant is set to ``tenant_id``
        when given; otherwise ``data``'s own tenant value is kept; otherwise
        ``DEFAULT_TENANT`` is used. Keys that are not attributes of
        ``model_class``, and values that are dicts or lists, are dropped
        before constructing the object. ``data`` itself is not modified.
        """
        values = dict(data)  # work on a copy; never mutate the caller's dict
        has_tenant = hasattr(self.model_class, self.tenant_field)
        if tenant_id and has_tenant:
            values[self.tenant_field] = tenant_id
        elif has_tenant and self.tenant_field not in values:
            values[self.tenant_field] = DEFAULT_TENANT
        # NOTE(review): dict/list values are dropped, presumably to skip nested
        # payloads; this also drops values meant for JSON columns -- confirm intent.
        valid = {k: v for k, v in values.items()
                 if hasattr(self.model_class, k) and not isinstance(v, (dict, list))}
        with db_manager.get_session() as session:
            obj = self.model_class(**valid)
            session.add(obj)
            session.flush()  # emit the INSERT so generated values (e.g. the PK) are populated
            return self._to_dict(obj)

    def update(self, id: int, data: Dict, tenant_id: Optional[str] = None) -> Optional[Dict]:
        """Apply ``data`` to the row with primary key ``id``; return the updated row.

        Keys that are not attributes of the object are ignored. The primary
        key and the tenant column can never be changed through this method.
        Returns ``None`` when no matching row exists in scope.
        """
        # Protect the PK and the tenant column. self.tenant_field covers subclasses
        # with a custom tenant column; 'tenant_id' stays protected for compatibility.
        protected = {'id', 'tenant_id', self.tenant_field}
        with db_manager.get_session() as session:
            obj = self._base_query(session, tenant_id).filter(self.model_class.id == id).first()
            if obj is None:
                return None
            for key, value in data.items():
                if key not in protected and hasattr(obj, key):
                    setattr(obj, key, value)
            session.flush()
            return self._to_dict(obj)

    def delete(self, id: int, tenant_id: Optional[str] = None) -> bool:
        """Delete the row with primary key ``id``; return whether it was found.

        When ``tenant_id`` is given, a row owned by another tenant counts as
        not found and is left untouched.
        """
        with db_manager.get_session() as session:
            obj = self._base_query(session, tenant_id).filter(self.model_class.id == id).first()
            if obj is None:
                return False
            session.delete(obj)
            return True

    def batch_delete(self, ids: List[int], tenant_id: Optional[str] = None) -> int:
        """Delete all in-scope rows whose primary key is in ``ids``.

        Returns the number of rows deleted as reported by ``Query.delete``.
        """
        if not ids:
            return 0  # nothing to delete: skip the session and the empty IN ()
        with db_manager.get_session() as session:
            query = self._base_query(session, tenant_id).filter(self.model_class.id.in_(ids))
            return query.delete(synchronize_session='fetch')

    def count(self, tenant_id: Optional[str] = None, filters: Optional[Dict] = None) -> int:
        """Return the number of in-scope rows matching ``filters``.

        ``filters`` follows the same rules as in ``list``.
        """
        with db_manager.get_session() as session:
            return self._apply_filters(self._base_query(session, tenant_id), filters).count()

    def _to_dict(self, obj) -> Dict:
        """Convert an ORM object to a plain dict. Subclasses may override.

        Delegates to ``obj.to_dict()`` when the object provides one. Otherwise
        it maps each column name of ``obj.__table__`` to the attribute of the
        same name, replacing any value that has an ``isoformat`` method
        (date/time values) with ``value.isoformat()``.
        """
        if hasattr(obj, 'to_dict'):
            return obj.to_dict()
        result = {}
        for col in obj.__table__.columns:
            # NOTE(review): assumes attribute names equal column names -- confirm
            # for models that map a column under a different attribute name.
            val = getattr(obj, col.name)
            if hasattr(val, 'isoformat'):
                val = val.isoformat()
            result[col.name] = val
        return result