Files
vibe_data_ana/src/data_access.py

269 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""数据访问层 - 提供隐私保护的数据访问接口。"""
import pandas as pd
import logging
from typing import Dict, Any, List, Optional
from pathlib import Path
from src.models import DataProfile, ColumnInfo
logger = logging.getLogger(__name__)
class DataLoadError(Exception):
"""数据加载错误。"""
pass
class DataAccessLayer:
"""
数据访问层,提供隐私保护的数据访问。
核心原则:
- AI 不能直接访问原始数据
- 只能通过工具获取聚合结果
- 数据画像只包含元数据和统计摘要
"""
def __init__(self, data: pd.DataFrame, file_path: str = ""):
"""
初始化数据访问层。
参数:
data: 原始数据(私有,不暴露给 AI
file_path: 数据文件路径
"""
self._data = data # 私有数据AI 不可访问
self._file_path = file_path
@classmethod
def load_from_file(cls, file_path: str, max_retries: int = 3, optimize_memory: bool = True) -> 'DataAccessLayer':
"""
从文件加载数据,支持多种编码和性能优化。
参数:
file_path: CSV 文件路径
max_retries: 最大重试次数
optimize_memory: 是否优化内存使用
返回:
DataAccessLayer 实例
异常:
DataLoadError: 数据加载失败
"""
encodings = ['utf-8', 'gbk', 'gb2312', 'latin1', 'iso-8859-1']
for encoding in encodings:
try:
logger.info(f"尝试使用编码 {encoding} 加载文件: {file_path}")
# 使用低内存模式加载大文件
data = pd.read_csv(file_path, encoding=encoding, low_memory=False)
# 检查数据是否为空
if data.empty:
raise DataLoadError(f"文件 {file_path} 为空")
# 检查数据大小并采样
if len(data) > 1_000_000:
logger.warning(f"数据过大({len(data)}采样到100万行")
data = data.sample(n=1_000_000, random_state=42)
# 优化内存使用
if optimize_memory:
from src.performance_optimization import DataLoadOptimizer
initial_memory = data.memory_usage(deep=True).sum() / 1024 / 1024
data = DataLoadOptimizer.optimize_dtypes(data)
final_memory = data.memory_usage(deep=True).sum() / 1024 / 1024
logger.info(f"内存优化: {initial_memory:.2f}MB -> {final_memory:.2f}MB (节省 {initial_memory - final_memory:.2f}MB)")
logger.info(f"成功加载数据: {len(data)}行, {len(data.columns)}")
return cls(data, file_path)
except UnicodeDecodeError:
logger.debug(f"编码 {encoding} 失败,尝试下一个")
continue
except Exception as e:
logger.error(f"加载文件失败 ({encoding}): {e}")
if encoding == encodings[-1]:
raise DataLoadError(f"无法加载文件 {file_path}: {e}")
continue
raise DataLoadError(f"无法加载文件 {file_path},尝试了所有编码")
def get_profile(self) -> DataProfile:
"""
生成数据画像(安全,不包含原始数据)。
返回:
DataProfile: 数据画像,包含元数据和统计摘要
"""
columns_info = []
for col_name in self._data.columns:
col_data = self._data[col_name]
# 推断数据类型
dtype = self._infer_column_type(col_data)
# 计算缺失率
missing_rate = col_data.isna().sum() / len(col_data)
# 计算唯一值数量
unique_count = col_data.nunique()
# 获取示例值最多5个去重
sample_values = col_data.dropna().unique()[:5].tolist()
# 计算统计信息
statistics = {}
if dtype == 'numeric':
statistics = {
'min': float(col_data.min()) if not col_data.isna().all() else None,
'max': float(col_data.max()) if not col_data.isna().all() else None,
'mean': float(col_data.mean()) if not col_data.isna().all() else None,
'std': float(col_data.std()) if not col_data.isna().all() else None,
'median': float(col_data.median()) if not col_data.isna().all() else None,
}
elif dtype == 'categorical':
value_counts = col_data.value_counts().head(10)
statistics = {
'top_values': value_counts.to_dict(),
'num_categories': unique_count,
}
columns_info.append(ColumnInfo(
name=col_name,
dtype=dtype,
missing_rate=float(missing_rate),
unique_count=int(unique_count),
sample_values=sample_values,
statistics=statistics
))
return DataProfile(
file_path=self._file_path,
row_count=len(self._data),
column_count=len(self._data.columns),
columns=columns_info,
inferred_type='unknown', # 将由 AI 推断
key_fields={},
quality_score=0.0,
summary=""
)
def _infer_column_type(self, col_data: pd.Series) -> str:
"""
推断列的数据类型。
参数:
col_data: 列数据
返回:
数据类型: 'numeric', 'categorical', 'datetime', 'text'
"""
# 检查是否为日期时间类型
if pd.api.types.is_datetime64_any_dtype(col_data):
return 'datetime'
# 尝试转换为日期时间
if col_data.dtype == 'object':
try:
sample = col_data.dropna().head(20)
if len(sample) == 0:
pass
else:
# 尝试用常见日期格式解析
date_formats = ['%Y-%m-%d', '%Y/%m/%d', '%Y-%m-%d %H:%M:%S', '%Y/%m/%d %H:%M:%S', '%d/%m/%Y', '%m/%d/%Y']
parsed = False
for fmt in date_formats:
try:
pd.to_datetime(sample, format=fmt)
parsed = True
break
except (ValueError, TypeError):
continue
if not parsed:
# 最后尝试自动推断,但用 infer_datetime_format
pd.to_datetime(sample, format='mixed', dayfirst=False)
parsed = True
if parsed:
return 'datetime'
except:
pass
# 检查是否为数值类型
if pd.api.types.is_numeric_dtype(col_data):
return 'numeric'
# 检查是否为分类类型(唯一值较少)
unique_ratio = col_data.nunique() / len(col_data)
if unique_ratio < 0.05 or col_data.nunique() < 20:
return 'categorical'
# 默认为文本类型
return 'text'
def execute_tool(self, tool: Any, **kwargs) -> Dict[str, Any]:
"""
执行工具并返回聚合结果(安全)。
参数:
tool: 分析工具实例
**kwargs: 工具参数
返回:
工具执行结果(聚合数据)
"""
try:
result = tool.execute(self._data, **kwargs)
return self._sanitize_result(result)
except Exception as e:
logger.error(f"工具 {tool.name} 执行失败: {e}")
return {
'success': False,
'error': str(e),
'tool': tool.name
}
def _sanitize_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
"""
确保结果不包含原始数据,只返回聚合数据。
参数:
result: 工具执行结果
返回:
过滤后的结果
"""
# 检查结果中是否有 DataFrame
sanitized = {}
for key, value in result.items():
if isinstance(value, pd.DataFrame):
# 限制返回的行数
if len(value) > 100:
logger.warning(f"结果包含 {len(value)} 行数据截断到100行")
value = value.head(100)
sanitized[key] = value.to_dict('records')
elif isinstance(value, pd.Series):
# 限制返回的行数
if len(value) > 100:
logger.warning(f"结果包含 {len(value)} 行数据截断到100行")
value = value.head(100)
sanitized[key] = value.to_dict()
else:
sanitized[key] = value
return sanitized
@property
def shape(self) -> tuple:
"""返回数据形状(行数,列数)。"""
return self._data.shape
@property
def columns(self) -> List[str]:
"""返回列名列表。"""
return self._data.columns.tolist()