# -*- coding: utf-8 -*-
"""
Repository base class.

All data access goes through the Repository layer, which automatically
attaches a tenant_id filter to queries.
"""
import logging
from typing import Any, Dict, List, Optional, Type

from sqlalchemy.orm import Session

from src.core.database import db_manager
from src.core.models import DEFAULT_TENANT

logger = logging.getLogger(__name__)


class BaseRepository:
    """
    Base class for repositories.

    Subclasses set ``model_class`` (and ``tenant_field`` when the tenant
    column is not called ``tenant_id``). Every query built here is scoped to
    a tenant when a tenant id is supplied and the model has ``tenant_field``.
    Each public method opens its own session via ``db_manager.get_session()``
    and returns plain dicts (see ``_to_dict``), not ORM objects.
    """

    model_class = None  # must be set by subclasses
    tenant_field = 'tenant_id'  # name of the tenant column on the model

    def _base_query(self, session: Session, tenant_id: Optional[str] = None):
        """Build a query over ``model_class``, filtered by tenant when possible.

        The tenant filter is applied only when ``tenant_id`` is truthy AND the
        model has ``tenant_field``; otherwise the query is not tenant-scoped.
        """
        query = session.query(self.model_class)
        if tenant_id and hasattr(self.model_class, self.tenant_field):
            query = query.filter(
                getattr(self.model_class, self.tenant_field) == tenant_id
            )
        return query

    def _apply_filters(self, query, filters: Optional[Dict] = None):
        """Add an equality filter for each ``field: value`` pair in ``filters``.

        Pairs whose value is None, or whose field is not an attribute of
        ``model_class``, are silently ignored.
        """
        if not filters:
            return query
        for field, value in filters.items():
            if value is not None and hasattr(self.model_class, field):
                query = query.filter(getattr(self.model_class, field) == value)
        return query

    def get_by_id(self, id: int, tenant_id: Optional[str] = None) -> Optional[Dict]:
        """Fetch one record by primary key ``id``.

        Returns:
            The record as a dict, or None if no matching row exists
            (within the tenant, when ``tenant_id`` is given).
        """
        with db_manager.get_session() as session:
            obj = (
                self._base_query(session, tenant_id)
                .filter(self.model_class.id == id)
                .first()
            )
            return self._to_dict(obj) if obj else None

    def list(self, tenant_id: Optional[str] = None, page: int = 1,
             per_page: int = 20, filters: Optional[Dict] = None,
             order_by=None) -> Dict[str, Any]:
        """Return one page of records plus pagination metadata.

        Args:
            tenant_id: restrict to this tenant (if the model is tenant-aware).
            page: 1-based page number; values below 1 are treated as 1.
            per_page: page size; values below 1 are treated as 1.
            filters: ``field: value`` equality filters (None values ignored).
            order_by: optional SQLAlchemy ordering clause.

        Returns:
            Dict with keys ``items``, ``page``, ``per_page``, ``total`` and
            ``total_pages``.
        """
        # A page below 1 would yield a negative OFFSET, and per_page below 1
        # would divide by zero when computing total_pages.
        page = max(page, 1)
        per_page = max(per_page, 1)

        with db_manager.get_session() as session:
            query = self._apply_filters(self._base_query(session, tenant_id), filters)
            if order_by is not None:
                query = query.order_by(order_by)
            total = query.count()
            rows = query.offset((page - 1) * per_page).limit(per_page).all()
            return {
                'items': [self._to_dict(row) for row in rows],
                'page': page,
                'per_page': per_page,
                'total': total,
                'total_pages': (total + per_page - 1) // per_page,
            }

    def create(self, data: Dict, tenant_id: Optional[str] = None) -> Dict:
        """Insert a new record and return it as a dict.

        The tenant field is set to ``tenant_id`` when given; otherwise it is
        set to ``DEFAULT_TENANT`` if ``data`` does not already contain it
        (only for tenant-aware models). The caller's ``data`` is not modified.
        Keys that are not attributes of ``model_class`` are dropped, as are
        values that are dicts or lists.
        """
        # Work on a copy so the tenant injection does not leak back to the caller.
        payload = dict(data)
        tenant_aware = hasattr(self.model_class, self.tenant_field)
        if tenant_id and tenant_aware:
            payload[self.tenant_field] = tenant_id
        elif tenant_aware and self.tenant_field not in payload:
            payload[self.tenant_field] = DEFAULT_TENANT

        # Only pass attributes the model defines. Note: ``list`` here is the
        # builtin type (function scope does not see the ``list`` method).
        valid = {
            key: value for key, value in payload.items()
            if hasattr(self.model_class, key) and not isinstance(value, (dict, list))
        }

        with db_manager.get_session() as session:
            obj = self.model_class(**valid)
            session.add(obj)
            # Flush so database-generated values (e.g. the id) are populated
            # before serialising.
            session.flush()
            return self._to_dict(obj)

    def update(self, id: int, data: Dict,
               tenant_id: Optional[str] = None) -> Optional[Dict]:
        """Update the record with primary key ``id`` from ``data``.

        ``id`` and the tenant field can never be changed through this method;
        keys that are not attributes of the object are ignored.

        Returns:
            The updated record as a dict, or None if it was not found.
        """
        # Protect both the default name and a subclass's custom tenant column.
        protected = {'id', 'tenant_id', self.tenant_field}
        with db_manager.get_session() as session:
            obj = (
                self._base_query(session, tenant_id)
                .filter(self.model_class.id == id)
                .first()
            )
            if not obj:
                return None
            for key, value in data.items():
                if key in protected or not hasattr(obj, key):
                    continue
                setattr(obj, key, value)
            session.flush()
            return self._to_dict(obj)

    def delete(self, id: int, tenant_id: Optional[str] = None) -> bool:
        """Delete the record with primary key ``id``.

        Returns:
            True if a record was found and marked for deletion, else False.
            (Persistence presumably happens when the session context exits —
            confirm against ``db_manager.get_session``.)
        """
        with db_manager.get_session() as session:
            obj = (
                self._base_query(session, tenant_id)
                .filter(self.model_class.id == id)
                .first()
            )
            if not obj:
                return False
            session.delete(obj)
            return True

    def batch_delete(self, ids: List[int], tenant_id: Optional[str] = None) -> int:
        """Bulk-delete records whose id is in ``ids``.

        Returns:
            The number of rows deleted; 0 when ``ids`` is empty or None
            (no query is issued in that case).
        """
        if not ids:
            return 0
        with db_manager.get_session() as session:
            query = self._base_query(session, tenant_id).filter(
                self.model_class.id.in_(ids)
            )
            return query.delete(synchronize_session='fetch')

    def count(self, tenant_id: Optional[str] = None,
              filters: Optional[Dict] = None) -> int:
        """Count records matching the tenant and the ``filters`` equality pairs."""
        with db_manager.get_session() as session:
            query = self._apply_filters(self._base_query(session, tenant_id), filters)
            return query.count()

    def _to_dict(self, obj) -> Dict:
        """Convert an ORM object to a dict; subclasses may override.

        Uses ``obj.to_dict()`` if the object defines it. Otherwise it maps each
        table column name to its value, converting values that have an
        ``isoformat`` method (e.g. dates/datetimes) to ISO strings.
        """
        if hasattr(obj, 'to_dict'):
            return obj.to_dict()
        result = {}
        for column in obj.__table__.columns:
            value = getattr(obj, column.name)
            if hasattr(value, 'isoformat'):
                value = value.isoformat()
            result[column.name] = value
        return result