# Source: assist/src/repositories/base.py (Python, 122 lines, 5.0 KiB)

# -*- coding: utf-8 -*-
"""
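Repository base class.
All data access goes through a repository, which adds a tenant_id filter to
queries when a tenant_id is supplied and the model defines that field.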
"""
import logging
from typing import Any, Dict, List, Optional, Type
from sqlalchemy.orm import Session
from src.core.database import db_manager
from src.core.models import DEFAULT_TENANT
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class BaseRepository:
    """
    Base class for repositories.

    Subclasses set ``model_class`` and, when the model's tenant attribute is
    not called ``tenant_id``, ``tenant_field``.

    Queries are restricted to a tenant only when a truthy ``tenant_id`` is
    passed AND ``model_class`` defines ``tenant_field``.
    NOTE(review): with no ``tenant_id`` the query is not tenant-scoped at
    all -- confirm that callers always pass one where isolation matters.

    Every public method opens its own session via
    ``db_manager.get_session()``.
    NOTE(review): presumably that context manager commits on a clean exit
    and rolls back on error -- confirm against ``src.core.database``.
    """

    model_class = None  # must be set by subclasses
    tenant_field = 'tenant_id'  # name of the tenant attribute on the model

    def _base_query(self, session: "Session", tenant_id: Optional[str] = None):
        """Return a query on ``model_class``, tenant-filtered when applicable."""
        q = session.query(self.model_class)
        if tenant_id and hasattr(self.model_class, self.tenant_field):
            q = q.filter(getattr(self.model_class, self.tenant_field) == tenant_id)
        return q

    def _apply_filters(self, query, filters: Optional[Dict] = None):
        """
        Add one equality filter per usable entry of ``filters``.

        Entries whose value is ``None`` or whose key is not an attribute of
        ``model_class`` are ignored. Returns the (possibly narrowed) query.
        """
        if not filters:
            return query
        for field, value in filters.items():
            if value is not None and hasattr(self.model_class, field):
                query = query.filter(getattr(self.model_class, field) == value)
        return query

    @staticmethod
    def _clamp_paging(page: int, per_page: int):
        """
        Return ``(page, per_page)`` with both values forced to be >= 1.

        Guards against a negative OFFSET (``page < 1``) and a division by
        zero when computing the page count (``per_page == 0``).
        """
        return max(1, page), max(1, per_page)

    def get_by_id(self, id: int, tenant_id: Optional[str] = None) -> Optional[Dict]:
        """Return the record with primary key ``id`` as a dict, or ``None``."""
        with db_manager.get_session() as session:
            q = self._base_query(session, tenant_id).filter(self.model_class.id == id)
            obj = q.first()
            return self._to_dict(obj) if obj is not None else None

    def list(self, tenant_id: Optional[str] = None, page: int = 1, per_page: int = 20,
             filters: Optional[Dict] = None, order_by=None) -> Dict[str, Any]:
        """
        Return one page of records plus paging metadata.

        Args:
            tenant_id: tenant to scope the query to (see ``_base_query``).
            page: 1-based page number; values below 1 are treated as 1.
            per_page: page size; values below 1 are treated as 1.
            filters: ``{attribute_name: value}`` equality filters.
            order_by: passed to ``Query.order_by`` when not ``None``.

        Returns:
            Dict with keys ``items``, ``page``, ``per_page``, ``total`` and
            ``total_pages``.
        """
        page, per_page = self._clamp_paging(page, per_page)
        with db_manager.get_session() as session:
            q = self._apply_filters(self._base_query(session, tenant_id), filters)
            if order_by is not None:
                q = q.order_by(order_by)
            total = q.count()
            items = q.offset((page - 1) * per_page).limit(per_page).all()
            return {
                'items': [self._to_dict(item) for item in items],
                'page': page, 'per_page': per_page,
                # Ceiling division: a partially filled last page still counts.
                'total': total, 'total_pages': (total + per_page - 1) // per_page
            }

    def create(self, data: Dict, tenant_id: Optional[str] = None) -> Dict:
        """
        Create a record from ``data`` and return it as a dict.

        When the model has ``tenant_field`` it is set to ``tenant_id`` if one
        is given, otherwise to ``DEFAULT_TENANT`` unless ``data`` already
        carries a value. Keys that are not model attributes, and values that
        are dicts or lists, are dropped. ``data`` itself is not modified.
        """
        # Shallow copy so injecting the tenant never leaks into the caller's dict.
        payload = dict(data)
        if hasattr(self.model_class, self.tenant_field):
            if tenant_id:
                payload[self.tenant_field] = tenant_id
            elif self.tenant_field not in payload:
                payload[self.tenant_field] = DEFAULT_TENANT
        # Keep only attributes the model defines; dict/list values are skipped.
        valid = {
            k: v for k, v in payload.items()
            if hasattr(self.model_class, k) and not isinstance(v, (dict, list))
        }
        with db_manager.get_session() as session:
            obj = self.model_class(**valid)
            session.add(obj)
            # Flush so database-generated values (e.g. the id) are populated.
            session.flush()
            return self._to_dict(obj)

    def update(self, id: int, data: Dict, tenant_id: Optional[str] = None) -> Optional[Dict]:
        """
        Update the record ``id`` with ``data``; return it as a dict.

        Returns ``None`` when no matching record exists. The primary key and
        the tenant attribute are never overwritten; keys the object does not
        have are ignored.
        """
        # Protect the configured tenant field as well as the literal default,
        # so subclasses overriding ``tenant_field`` cannot switch tenants.
        protected = {'id', 'tenant_id', self.tenant_field}
        with db_manager.get_session() as session:
            q = self._base_query(session, tenant_id).filter(self.model_class.id == id)
            obj = q.first()
            if obj is None:
                return None
            for k, v in data.items():
                if k not in protected and hasattr(obj, k):
                    setattr(obj, k, v)
            session.flush()
            return self._to_dict(obj)

    def delete(self, id: int, tenant_id: Optional[str] = None) -> bool:
        """Delete the record ``id``; return ``False`` if it was not found."""
        with db_manager.get_session() as session:
            q = self._base_query(session, tenant_id).filter(self.model_class.id == id)
            obj = q.first()
            if obj is None:
                return False
            session.delete(obj)
            return True

    def batch_delete(self, ids: List[int], tenant_id: Optional[str] = None) -> int:
        """Delete the records whose id is in ``ids``; return how many were deleted."""
        # Nothing to delete: avoid opening a session for an empty IN () clause.
        if not ids:
            return 0
        with db_manager.get_session() as session:
            q = self._base_query(session, tenant_id).filter(self.model_class.id.in_(ids))
            count = q.delete(synchronize_session='fetch')
            return count

    def count(self, tenant_id: Optional[str] = None, filters: Optional[Dict] = None) -> int:
        """Return the number of records matching the tenant and ``filters``."""
        with db_manager.get_session() as session:
            q = self._apply_filters(self._base_query(session, tenant_id), filters)
            return q.count()

    def _to_dict(self, obj) -> Dict:
        """
        Convert an ORM object to a dict. Subclasses may override.

        Uses ``obj.to_dict()`` when the object provides it; otherwise reads
        every column of ``obj.__table__`` by name, converting values that
        have an ``isoformat`` method (dates, datetimes) to ISO strings.
        """
        if hasattr(obj, 'to_dict'):
            return obj.to_dict()
        result = {}
        for col in obj.__table__.columns:
            val = getattr(obj, col.name)
            if hasattr(val, 'isoformat'):
                val = val.isoformat()
            result[col.name] = val
        return result