"""数据访问层 - 提供隐私保护的数据访问接口。""" import pandas as pd import logging from typing import Dict, Any, List, Optional from pathlib import Path from src.models import DataProfile, ColumnInfo logger = logging.getLogger(__name__) class DataLoadError(Exception): """数据加载错误。""" pass class DataAccessLayer: """ 数据访问层,提供隐私保护的数据访问。 核心原则: - AI 不能直接访问原始数据 - 只能通过工具获取聚合结果 - 数据画像只包含元数据和统计摘要 """ def __init__(self, data: pd.DataFrame, file_path: str = ""): """ 初始化数据访问层。 参数: data: 原始数据(私有,不暴露给 AI) file_path: 数据文件路径 """ self._data = data # 私有数据,AI 不可访问 self._file_path = file_path @classmethod def load_from_file(cls, file_path: str, max_retries: int = 3, optimize_memory: bool = True) -> 'DataAccessLayer': """ 从文件加载数据,支持多种编码和性能优化。 参数: file_path: CSV 文件路径 max_retries: 最大重试次数 optimize_memory: 是否优化内存使用 返回: DataAccessLayer 实例 异常: DataLoadError: 数据加载失败 """ encodings = ['utf-8', 'gbk', 'gb2312', 'latin1', 'iso-8859-1'] for encoding in encodings: try: logger.info(f"尝试使用编码 {encoding} 加载文件: {file_path}") # 使用低内存模式加载大文件 data = pd.read_csv(file_path, encoding=encoding, low_memory=False) # 检查数据是否为空 if data.empty: raise DataLoadError(f"文件 {file_path} 为空") # 检查数据大小并采样 if len(data) > 1_000_000: logger.warning(f"数据过大({len(data)}行),采样到100万行") data = data.sample(n=1_000_000, random_state=42) # 优化内存使用 if optimize_memory: from src.performance_optimization import DataLoadOptimizer initial_memory = data.memory_usage(deep=True).sum() / 1024 / 1024 data = DataLoadOptimizer.optimize_dtypes(data) final_memory = data.memory_usage(deep=True).sum() / 1024 / 1024 logger.info(f"内存优化: {initial_memory:.2f}MB -> {final_memory:.2f}MB (节省 {initial_memory - final_memory:.2f}MB)") logger.info(f"成功加载数据: {len(data)}行, {len(data.columns)}列") return cls(data, file_path) except UnicodeDecodeError: logger.debug(f"编码 {encoding} 失败,尝试下一个") continue except Exception as e: logger.error(f"加载文件失败 ({encoding}): {e}") if encoding == encodings[-1]: raise DataLoadError(f"无法加载文件 {file_path}: {e}") continue raise DataLoadError(f"无法加载文件 {file_path},尝试了所有编码") def get_profile(self) -> DataProfile: """ 生成数据画像(安全,不包含原始数据)。 返回: DataProfile: 数据画像,包含元数据和统计摘要 """ columns_info = [] for col_name in self._data.columns: col_data = self._data[col_name] # 推断数据类型 dtype = self._infer_column_type(col_data) # 计算缺失率 missing_rate = col_data.isna().sum() / len(col_data) # 计算唯一值数量 unique_count = col_data.nunique() # 获取示例值(最多5个,去重) sample_values = col_data.dropna().unique()[:5].tolist() # 计算统计信息 statistics = {} if dtype == 'numeric': statistics = { 'min': float(col_data.min()) if not col_data.isna().all() else None, 'max': float(col_data.max()) if not col_data.isna().all() else None, 'mean': float(col_data.mean()) if not col_data.isna().all() else None, 'std': float(col_data.std()) if not col_data.isna().all() else None, 'median': float(col_data.median()) if not col_data.isna().all() else None, } elif dtype == 'categorical': value_counts = col_data.value_counts().head(10) statistics = { 'top_values': value_counts.to_dict(), 'num_categories': unique_count, } columns_info.append(ColumnInfo( name=col_name, dtype=dtype, missing_rate=float(missing_rate), unique_count=int(unique_count), sample_values=sample_values, statistics=statistics )) return DataProfile( file_path=self._file_path, row_count=len(self._data), column_count=len(self._data.columns), columns=columns_info, inferred_type='unknown', # 将由 AI 推断 key_fields={}, quality_score=0.0, summary="" ) def _infer_column_type(self, col_data: pd.Series) -> str: """ 推断列的数据类型。 参数: col_data: 列数据 返回: 数据类型: 'numeric', 'categorical', 'datetime', 'text' """ # 检查是否为日期时间类型 if pd.api.types.is_datetime64_any_dtype(col_data): return 'datetime' # 尝试转换为日期时间 if col_data.dtype == 'object': try: sample = col_data.dropna().head(20) if len(sample) == 0: pass else: # 尝试用常见日期格式解析 date_formats = ['%Y-%m-%d', '%Y/%m/%d', '%Y-%m-%d %H:%M:%S', '%Y/%m/%d %H:%M:%S', '%d/%m/%Y', '%m/%d/%Y'] parsed = False for fmt in date_formats: try: pd.to_datetime(sample, format=fmt) parsed = True break except (ValueError, TypeError): continue if not parsed: # 最后尝试自动推断,但用 infer_datetime_format pd.to_datetime(sample, format='mixed', dayfirst=False) parsed = True if parsed: return 'datetime' except: pass # 检查是否为数值类型 if pd.api.types.is_numeric_dtype(col_data): return 'numeric' # 检查是否为分类类型(唯一值较少) unique_ratio = col_data.nunique() / len(col_data) if unique_ratio < 0.05 or col_data.nunique() < 20: return 'categorical' # 默认为文本类型 return 'text' def execute_tool(self, tool: Any, **kwargs) -> Dict[str, Any]: """ 执行工具并返回聚合结果(安全)。 参数: tool: 分析工具实例 **kwargs: 工具参数 返回: 工具执行结果(聚合数据) """ try: result = tool.execute(self._data, **kwargs) return self._sanitize_result(result) except Exception as e: logger.error(f"工具 {tool.name} 执行失败: {e}") return { 'success': False, 'error': str(e), 'tool': tool.name } def _sanitize_result(self, result: Dict[str, Any]) -> Dict[str, Any]: """ 确保结果不包含原始数据,只返回聚合数据。 参数: result: 工具执行结果 返回: 过滤后的结果 """ # 检查结果中是否有 DataFrame sanitized = {} for key, value in result.items(): if isinstance(value, pd.DataFrame): # 限制返回的行数 if len(value) > 100: logger.warning(f"结果包含 {len(value)} 行数据,截断到100行") value = value.head(100) sanitized[key] = value.to_dict('records') elif isinstance(value, pd.Series): # 限制返回的行数 if len(value) > 100: logger.warning(f"结果包含 {len(value)} 行数据,截断到100行") value = value.head(100) sanitized[key] = value.to_dict() else: sanitized[key] = value return sanitized @property def shape(self) -> tuple: """返回数据形状(行数,列数)。""" return self._data.shape @property def columns(self) -> List[str]: """返回列名列表。""" return self._data.columns.tolist()