251 lines
8.8 KiB
Python
251 lines
8.8 KiB
Python
|
|
"""Data access layer - provides a privacy-preserving data access interface."""
|
|||
|
|
|
|||
|
|
import pandas as pd
|
|||
|
|
import logging
|
|||
|
|
from typing import Dict, Any, List, Optional
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
from src.models import DataProfile, ColumnInfo
|
|||
|
|
|
|||
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DataLoadError(Exception):
    """Error type for data loading failures."""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DataAccessLayer:
    """
    Privacy-oriented access layer around a pandas DataFrame.

    Core principles:
    - The AI must not access the raw data directly.
    - Results are obtained only through tools, as aggregated output.
    - The data profile contains only metadata and statistical summaries.

    NOTE(review): ``get_profile`` also returns up to five distinct sample
    values per column and the top value counts of categorical columns, which
    are raw cell values -- confirm this is acceptable for the privacy goal.
    """

    def __init__(self, data: pd.DataFrame, file_path: str = ""):
        """
        Initialise the data access layer.

        Args:
            data: Raw data (kept private, not exposed to the AI).
            file_path: Path of the data file the frame came from.
        """
        self._data = data  # private raw data, must not be handed to the AI
        self._file_path = file_path

    @classmethod
    def load_from_file(cls, file_path: str, max_retries: int = 3, optimize_memory: bool = True) -> 'DataAccessLayer':
        """
        Load a CSV file, trying several text encodings in turn.

        Only the read itself is retried with the next encoding. Errors that
        cannot depend on the encoding (missing/unreadable file, empty file,
        post-processing failures) are reported immediately.

        Args:
            file_path: Path of the CSV file.
            max_retries: Accepted for backward compatibility; this method
                does not use it (the attempts are driven by the encoding
                list).
            optimize_memory: When true, run the frame through
                ``DataLoadOptimizer.optimize_dtypes`` and log the memory
                usage before and after.

        Returns:
            A DataAccessLayer wrapping the loaded (possibly sampled) data.

        Raises:
            DataLoadError: The file cannot be opened, is empty, cannot be
                parsed with any of the encodings, or post-processing fails.
        """
        encodings = ['utf-8', 'gbk', 'gb2312', 'latin1', 'iso-8859-1']

        data = None
        last_error = None
        for encoding in encodings:
            logger.info(f"尝试使用编码 {encoding} 加载文件: {file_path}")
            try:
                # low_memory=False makes pandas parse the file in one pass
                # instead of in chunks, which avoids mixed-dtype inference at
                # the price of a higher peak memory use.
                data = pd.read_csv(file_path, encoding=encoding, low_memory=False)
            except UnicodeDecodeError:
                logger.debug(f"编码 {encoding} 失败,尝试下一个")
            except pd.errors.EmptyDataError as e:
                # A file without any columns is empty under every encoding.
                raise DataLoadError(f"文件 {file_path} 为空") from e
            except OSError as e:
                # Missing file, permission problems, etc. do not depend on
                # the encoding, so retrying is pointless.
                logger.error(f"加载文件失败 ({encoding}): {e}")
                raise DataLoadError(f"无法加载文件 {file_path}: {e}") from e
            except Exception as e:
                # Best effort: a parse failure may be a side effect of a wrong
                # encoding, so remember it and try the next candidate.
                logger.error(f"加载文件失败 ({encoding}): {e}")
                last_error = e
            else:
                break

        if data is None:
            if last_error is not None:
                raise DataLoadError(f"无法加载文件 {file_path}: {last_error}") from last_error
            raise DataLoadError(f"无法加载文件 {file_path},尝试了所有编码")

        # Header-only files parse successfully but contain no rows.
        if data.empty:
            raise DataLoadError(f"文件 {file_path} 为空")

        try:
            # Cap very large inputs with a reproducible random sample.
            if len(data) > 1_000_000:
                logger.warning(f"数据过大({len(data)}行),采样到100万行")
                data = data.sample(n=1_000_000, random_state=42)

            if optimize_memory:
                from src.performance_optimization import DataLoadOptimizer
                initial_memory = data.memory_usage(deep=True).sum() / 1024 / 1024
                data = DataLoadOptimizer.optimize_dtypes(data)
                final_memory = data.memory_usage(deep=True).sum() / 1024 / 1024
                logger.info(f"内存优化: {initial_memory:.2f}MB -> {final_memory:.2f}MB (节省 {initial_memory - final_memory:.2f}MB)")
        except Exception as e:
            logger.error(f"加载文件失败 ({encoding}): {e}")
            raise DataLoadError(f"无法加载文件 {file_path}: {e}") from e

        logger.info(f"成功加载数据: {len(data)}行, {len(data.columns)}列")
        return cls(data, file_path)

    def get_profile(self) -> 'DataProfile':
        """
        Build the data profile of the wrapped frame.

        For every column the profile records the inferred type, the missing
        rate, the number of distinct values, up to five distinct non-null
        sample values and type-specific statistics (min/max/mean/std/median
        for numeric columns, top-10 value counts for categorical columns).

        Returns:
            DataProfile with per-column ColumnInfo entries. ``inferred_type``,
            ``key_fields``, ``quality_score`` and ``summary`` are filled with
            placeholder values here.
        """
        row_count = len(self._data)
        columns_info = []

        for col_name in self._data.columns:
            col_data = self._data[col_name]

            dtype = self._infer_column_type(col_data)

            # Guard the division so that a zero-row frame yields 0.0.
            missing_rate = col_data.isna().sum() / row_count if row_count else 0.0

            unique_count = col_data.nunique()

            # Up to five distinct non-null example values.
            sample_values = col_data.dropna().unique()[:5].tolist()

            statistics = {}
            if dtype == 'numeric':
                # Computed once; statistics of an all-null column are None.
                has_values = bool(col_data.notna().any())
                statistics = {
                    stat: float(getattr(col_data, stat)()) if has_values else None
                    for stat in ('min', 'max', 'mean', 'std', 'median')
                }
            elif dtype == 'categorical':
                value_counts = col_data.value_counts().head(10)
                statistics = {
                    'top_values': value_counts.to_dict(),
                    'num_categories': unique_count,
                }

            columns_info.append(ColumnInfo(
                name=col_name,
                dtype=dtype,
                missing_rate=float(missing_rate),
                unique_count=int(unique_count),
                sample_values=sample_values,
                statistics=statistics
            ))

        return DataProfile(
            file_path=self._file_path,
            row_count=row_count,
            column_count=len(self._data.columns),
            columns=columns_info,
            inferred_type='unknown',  # to be inferred by the AI
            key_fields={},
            quality_score=0.0,
            summary=""
        )

    def _infer_column_type(self, col_data: pd.Series) -> str:
        """
        Infer the logical type of a column.

        Args:
            col_data: The column to classify.

        Returns:
            One of 'numeric', 'categorical', 'datetime', 'text'.
        """
        # Native datetime dtypes need no probing.
        if pd.api.types.is_datetime64_any_dtype(col_data):
            return 'datetime'

        # Object/string columns may hold date strings: probe the first 100
        # non-null values. An empty probe proves nothing and is skipped, so
        # empty or all-null columns are not reported as datetime.
        if pd.api.types.is_object_dtype(col_data) or pd.api.types.is_string_dtype(col_data):
            probe = col_data.dropna().head(100)
            if not probe.empty:
                try:
                    pd.to_datetime(probe)
                except Exception:
                    # Best-effort probe: any conversion failure just means
                    # "not a datetime column". Unlike a bare ``except`` this
                    # lets KeyboardInterrupt/SystemExit propagate.
                    pass
                else:
                    return 'datetime'

        if pd.api.types.is_numeric_dtype(col_data):
            return 'numeric'

        # Few distinct values -> categorical. The absolute check runs first so
        # that the ratio is only computed when there are at least 20 rows,
        # which avoids dividing by zero for empty columns.
        unique_count = col_data.nunique()
        if unique_count < 20 or unique_count / len(col_data) < 0.05:
            return 'categorical'

        # Everything else is treated as free text.
        return 'text'

    def execute_tool(self, tool: Any, **kwargs) -> Dict[str, Any]:
        """
        Run an analysis tool on the private data and return its result.

        Args:
            tool: Tool instance; must provide ``execute(data, **kwargs)``.
                Its ``name`` attribute is used in error reports when present.
            **kwargs: Arguments forwarded to the tool.

        Returns:
            The tool result after ``_sanitize_result``. If the tool or the
            sanitising step raises, a dict with ``success=False``, the error
            text and the tool name is returned instead.
        """
        try:
            result = tool.execute(self._data, **kwargs)
            return self._sanitize_result(result)
        except Exception as e:
            # Fall back to the class name so that a tool without ``name``
            # cannot raise a second exception from inside this handler.
            tool_name = getattr(tool, 'name', type(tool).__name__)
            logger.error(f"工具 {tool_name} 执行失败: {e}")
            return {
                'success': False,
                'error': str(e),
                'tool': tool_name
            }

    def _sanitize_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Convert pandas objects in a tool result to plain, size-limited data.

        DataFrames become lists of row dicts and Series become dicts, each
        truncated to at most 100 rows. Nested dicts and lists are processed
        recursively; all other values are passed through unchanged.

        NOTE(review): this only limits how many rows leave the layer; it does
        not check that the values are aggregates.

        Args:
            result: Result dict returned by a tool.

        Returns:
            A new dict with the converted values.
        """
        return {key: self._sanitize_value(value) for key, value in result.items()}

    @classmethod
    def _sanitize_value(cls, value: Any) -> Any:
        """Convert one result value, recursing into dicts and lists."""
        if isinstance(value, (pd.DataFrame, pd.Series)):
            # Limit the number of rows that can be returned.
            if len(value) > 100:
                logger.warning(f"结果包含 {len(value)} 行数据,截断到100行")
                value = value.head(100)
            if isinstance(value, pd.DataFrame):
                return value.to_dict('records')
            return value.to_dict()
        if isinstance(value, dict):
            return {key: cls._sanitize_value(item) for key, item in value.items()}
        if isinstance(value, list):
            return [cls._sanitize_value(item) for item in value]
        return value

    @property
    def shape(self) -> tuple:
        """Return the data shape as (rows, columns)."""
        return self._data.shape

    @property
    def columns(self) -> List[str]:
        """Return the list of column names."""
        return self._data.columns.tolist()
|