safe: Repository 层 + 旧代码清理 + tasks.md(从 v2.0 分支安全提取)
This commit is contained in:
121
src/repositories/base.py
Normal file
121
src/repositories/base.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Repository 基类
|
||||
所有数据访问通过 Repository 层,自动附加 tenant_id 过滤。
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from sqlalchemy.orm import Session
|
||||
from src.core.database import db_manager
|
||||
from src.core.models import DEFAULT_TENANT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseRepository:
|
||||
"""
|
||||
Repository 基类。子类只需指定 model_class 和 tenant_field。
|
||||
所有查询自动按 tenant_id 过滤(如果模型有该字段)。
|
||||
"""
|
||||
model_class = None # 子类必须设置
|
||||
tenant_field = 'tenant_id' # 默认租户字段名
|
||||
|
||||
def _base_query(self, session: Session, tenant_id: str = None):
|
||||
"""构建带 tenant_id 过滤的基础查询"""
|
||||
q = session.query(self.model_class)
|
||||
if tenant_id and hasattr(self.model_class, self.tenant_field):
|
||||
q = q.filter(getattr(self.model_class, self.tenant_field) == tenant_id)
|
||||
return q
|
||||
|
||||
def get_by_id(self, id: int, tenant_id: str = None) -> Optional[Dict]:
|
||||
"""按 ID 查询"""
|
||||
with db_manager.get_session() as session:
|
||||
q = self._base_query(session, tenant_id).filter(self.model_class.id == id)
|
||||
obj = q.first()
|
||||
return self._to_dict(obj) if obj else None
|
||||
|
||||
def list(self, tenant_id: str = None, page: int = 1, per_page: int = 20,
|
||||
filters: Dict = None, order_by=None) -> Dict[str, Any]:
|
||||
"""分页列表查询"""
|
||||
with db_manager.get_session() as session:
|
||||
q = self._base_query(session, tenant_id)
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if value is not None and hasattr(self.model_class, field):
|
||||
q = q.filter(getattr(self.model_class, field) == value)
|
||||
if order_by is not None:
|
||||
q = q.order_by(order_by)
|
||||
total = q.count()
|
||||
items = q.offset((page - 1) * per_page).limit(per_page).all()
|
||||
return {
|
||||
'items': [self._to_dict(item) for item in items],
|
||||
'page': page, 'per_page': per_page,
|
||||
'total': total, 'total_pages': (total + per_page - 1) // per_page
|
||||
}
|
||||
|
||||
def create(self, data: Dict, tenant_id: str = None) -> Dict:
|
||||
"""创建记录,自动设置 tenant_id"""
|
||||
with db_manager.get_session() as session:
|
||||
if tenant_id and hasattr(self.model_class, self.tenant_field):
|
||||
data[self.tenant_field] = tenant_id
|
||||
elif hasattr(self.model_class, self.tenant_field) and self.tenant_field not in data:
|
||||
data[self.tenant_field] = DEFAULT_TENANT
|
||||
# 只保留模型有的字段
|
||||
valid = {k: v for k, v in data.items() if hasattr(self.model_class, k) and not isinstance(v, (dict, list))}
|
||||
obj = self.model_class(**valid)
|
||||
session.add(obj)
|
||||
session.flush()
|
||||
result = self._to_dict(obj)
|
||||
return result
|
||||
|
||||
def update(self, id: int, data: Dict, tenant_id: str = None) -> Optional[Dict]:
|
||||
"""更新记录"""
|
||||
with db_manager.get_session() as session:
|
||||
q = self._base_query(session, tenant_id).filter(self.model_class.id == id)
|
||||
obj = q.first()
|
||||
if not obj:
|
||||
return None
|
||||
for k, v in data.items():
|
||||
if hasattr(obj, k) and k not in ('id', 'tenant_id'):
|
||||
setattr(obj, k, v)
|
||||
session.flush()
|
||||
return self._to_dict(obj)
|
||||
|
||||
def delete(self, id: int, tenant_id: str = None) -> bool:
|
||||
"""删除记录"""
|
||||
with db_manager.get_session() as session:
|
||||
q = self._base_query(session, tenant_id).filter(self.model_class.id == id)
|
||||
obj = q.first()
|
||||
if not obj:
|
||||
return False
|
||||
session.delete(obj)
|
||||
return True
|
||||
|
||||
def batch_delete(self, ids: List[int], tenant_id: str = None) -> int:
|
||||
"""批量删除,返回实际删除数量"""
|
||||
with db_manager.get_session() as session:
|
||||
q = self._base_query(session, tenant_id).filter(self.model_class.id.in_(ids))
|
||||
count = q.delete(synchronize_session='fetch')
|
||||
return count
|
||||
|
||||
def count(self, tenant_id: str = None, filters: Dict = None) -> int:
|
||||
"""计数"""
|
||||
with db_manager.get_session() as session:
|
||||
q = self._base_query(session, tenant_id)
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if value is not None and hasattr(self.model_class, field):
|
||||
q = q.filter(getattr(self.model_class, field) == value)
|
||||
return q.count()
|
||||
|
||||
def _to_dict(self, obj) -> Dict:
|
||||
"""将 ORM 对象转为字典。子类可覆盖。"""
|
||||
if hasattr(obj, 'to_dict'):
|
||||
return obj.to_dict()
|
||||
result = {}
|
||||
for col in obj.__table__.columns:
|
||||
val = getattr(obj, col.name)
|
||||
if hasattr(val, 'isoformat'):
|
||||
val = val.isoformat()
|
||||
result[col.name] = val
|
||||
return result
|
||||
Reference in New Issue
Block a user