157 lines
6.1 KiB
Python
157 lines
6.1 KiB
Python
"""数据访问层的基于属性的测试。"""
|
||
|
||
import pytest
|
||
import pandas as pd
|
||
import numpy as np
|
||
from hypothesis import given, strategies as st, settings, HealthCheck
|
||
from typing import Dict, Any
|
||
|
||
from src.data_access import DataAccessLayer
|
||
|
||
|
||
# Strategy for generating random DataFrames.
|
||
@st.composite
def dataframe_strategy(draw):
    """Build a random ``pandas.DataFrame`` for the property tests.

    The frame has between 10 and 1000 rows and between 2 and 20 columns
    named ``col_0``, ``col_1``, ... Each column's kind is drawn
    independently from ``int``, ``float``, ``str`` and ``datetime``.

    Args:
        draw: The draw callable injected by ``st.composite``.

    Returns:
        A ``pandas.DataFrame`` whose columns all have the drawn row count.
    """
    row_total = draw(st.integers(min_value=10, max_value=1000))
    column_total = draw(st.integers(min_value=2, max_value=20))

    def fixed_length(elements):
        # Every column must contain exactly ``row_total`` values.
        return st.lists(elements, min_size=row_total, max_size=row_total)

    # Element strategies for the kinds whose values are drawn by Hypothesis.
    element_strategies = {
        'int': st.integers(min_value=-1000, max_value=1000),
        'float': st.floats(min_value=-1000.0, max_value=1000.0, allow_nan=False, allow_infinity=False),
        'str': st.text(min_size=1, max_size=20, alphabet=st.characters(blacklist_categories=('Cs',))),
    }
    kind_choice = st.sampled_from(['int', 'float', 'str', 'datetime'])

    columns = {}
    for position in range(column_total):
        kind = draw(kind_choice)
        elements = element_strategies.get(kind)
        if elements is None:
            # datetime: consecutive daily timestamps starting 2020-01-01;
            # these values are fixed rather than drawn from Hypothesis.
            values = pd.date_range('2020-01-01', periods=row_total, freq='D').tolist()
        else:
            values = draw(fixed_length(elements))
        columns[f'col_{position}'] = values

    return pd.DataFrame(columns)
|
||
|
||
|
||
class TestDataAccessProperties:
    """Property-based tests for the data access layer.

    Every test draws a random DataFrame from ``dataframe_strategy``. That
    strategy can produce large examples (up to 1000 rows by 20 columns), so
    each test suppresses ``HealthCheck.data_too_large``.
    """

    # Feature: true-ai-agent, Property 18: data access restriction
    @given(df=dataframe_strategy())
    @settings(max_examples=20, deadline=None, suppress_health_check=[HealthCheck.data_too_large])
    def test_property_18_data_access_restriction(self, df):
        """Property 18: data access restriction.

        Validates requirement: constraint 5.3.

        For any data, the data profile should contain only metadata and
        statistical summaries; it should not contain the complete raw
        row-level data.
        """
        # Create the data access layer and fetch the data profile.
        dal = DataAccessLayer(df, file_path="test.csv")
        profile = dal.get_profile()

        # 1. Row and column counts are metadata and must match the frame.
        assert profile.row_count == len(df)
        assert profile.column_count == len(df.columns)

        # 2. There is one column descriptor per DataFrame column.
        assert len(profile.columns) == len(df.columns)

        for col_info in profile.columns:
            # 3. Sample values must be capped (at most 5).
            assert len(col_info.sample_values) <= 5

            # 4. Statistics of numeric columns must be aggregates, not raw
            #    data: each one is a single scalar (or None), never an
            #    array-like container.
            if col_info.dtype == 'numeric' and col_info.statistics:
                for stat_value in col_info.statistics.values():
                    assert not isinstance(stat_value, (list, np.ndarray, pd.Series))
                    assert stat_value is None or isinstance(stat_value, (int, float))

            # 5. The missing rate is an aggregate ratio between 0 and 1.
            assert 0.0 <= col_info.missing_rate <= 1.0

            # 6. The unique-value count is an aggregate, non-negative integer.
            assert isinstance(col_info.unique_count, int)
            assert col_info.unique_count >= 0

        # 7. The serialized profile must not embed the bulk of the raw data.
        #    The raw data holds n_rows * n_cols cells; the profile should
        #    only carry metadata and a few samples. The bound below is a
        #    deliberately coarse estimate: at most 100 characters of JSON
        #    per raw cell.
        profile_json = profile.to_json()
        original_data_size = len(df) * len(df.columns)
        assert len(profile_json) < original_data_size * 100  # rough estimate

    @given(df=dataframe_strategy())
    @settings(max_examples=10, deadline=None, suppress_health_check=[HealthCheck.data_too_large])
    def test_data_profile_completeness(self, df):
        """Check that the data profile is complete.

        The profile should contain every required metadata field.
        """
        dal = DataAccessLayer(df, file_path="test.csv")
        profile = dal.get_profile()

        # Required top-level fields are present.
        assert profile.file_path == "test.csv"
        assert profile.row_count > 0
        assert profile.column_count > 0
        assert len(profile.columns) > 0
        assert profile.inferred_type is not None

        # Every column descriptor is complete.
        for col_info in profile.columns:
            assert col_info.name is not None
            assert col_info.dtype in ['numeric', 'categorical', 'datetime', 'text']
            assert 0.0 <= col_info.missing_rate <= 1.0
            assert col_info.unique_count >= 0
            assert isinstance(col_info.sample_values, list)
            assert isinstance(col_info.statistics, dict)

    @given(df=dataframe_strategy())
    @settings(max_examples=10, deadline=None, suppress_health_check=[HealthCheck.data_too_large])
    def test_column_type_inference(self, df):
        """Check that column type inference is correct.

        The inferred type should be consistent with the actual pandas dtype
        of the column.
        """
        dal = DataAccessLayer(df, file_path="test.csv")
        profile = dal.get_profile()

        for col_info in profile.columns:
            actual_dtype = df[col_info.name].dtype

            # The inferred type must be plausible for the underlying dtype.
            if pd.api.types.is_numeric_dtype(actual_dtype):
                assert col_info.dtype in ['numeric', 'categorical']
            elif pd.api.types.is_datetime64_any_dtype(actual_dtype):
                assert col_info.dtype == 'datetime'
            elif pd.api.types.is_object_dtype(actual_dtype):
                assert col_info.dtype in ['categorical', 'text', 'datetime']
|