274 lines
12 KiB
Python
274 lines
12 KiB
Python
"""数据理解引擎的基于属性的测试。"""
|
||
|
||
import pytest
|
||
import pandas as pd
|
||
import numpy as np
|
||
from hypothesis import given, strategies as st, settings, assume
|
||
from typing import Dict, Any
|
||
|
||
from src.engines.data_understanding import (
|
||
generate_basic_stats,
|
||
understand_data,
|
||
_infer_column_type,
|
||
_infer_data_type,
|
||
_identify_key_fields,
|
||
_evaluate_data_quality
|
||
)
|
||
from src.models import DataProfile, ColumnInfo
|
||
|
||
|
||
# Hypothesis 策略用于生成测试数据
|
||
|
||
@st.composite
def dataframe_strategy(draw, min_rows=10, max_rows=100, min_cols=2, max_cols=10):
    """Generate a random DataFrame with a mix of int/float/str/datetime columns."""
    row_count = draw(st.integers(min_value=min_rows, max_value=max_rows))
    col_count = draw(st.integers(min_value=min_cols, max_value=max_cols))

    columns = {}
    for idx in range(col_count):
        # Draw the column kind first, then its values — keeps the draw order
        # identical for Hypothesis shrinking/reproduction.
        kind = draw(st.sampled_from(['int', 'float', 'str', 'datetime']))
        label = f'col_{idx}'

        if kind == 'datetime':
            # Datetime columns are a deterministic daily sequence (not drawn).
            columns[label] = pd.date_range(
                start=pd.Timestamp('2020-01-01'), periods=row_count, freq='D'
            )
            continue

        if kind == 'int':
            element = st.integers(min_value=-1000, max_value=1000)
        elif kind == 'float':
            element = st.floats(
                min_value=-1000.0, max_value=1000.0,
                allow_nan=False, allow_infinity=False
            )
        else:  # 'str'
            element = st.text(
                min_size=1, max_size=10,
                alphabet=st.characters(whitelist_categories=('Lu', 'Ll'))
            )

        columns[label] = draw(st.lists(element, min_size=row_count, max_size=row_count))

    return pd.DataFrame(columns)
|
||
|
||
|
||
# Feature: true-ai-agent, Property 1: 数据类型识别
|
||
@given(df=dataframe_strategy(min_rows=10, max_rows=100))
@settings(max_examples=20, deadline=None)
def test_data_type_inference(df):
    """
    Property 1: for any valid CSV file, the data-understanding engine should
    infer the business data type (e.g. ticket, sales, user) and that inference
    should be grounded in column names, dtypes and value distributions.

    Verifies requirement: scenario-1 acceptance criterion 1.
    """
    profile = understand_data(file_path='test.csv', data=df)

    # An inferred type must exist and must come from the known vocabulary.
    assert profile.inferred_type is not None, "推断的数据类型不应为 None"
    known_types = ['ticket', 'sales', 'user', 'unknown']
    assert profile.inferred_type in known_types, \
        f"推断的数据类型应该是预定义的类型之一,但得到:{profile.inferred_type}"

    # The inference should be backed by data features — at minimum, a
    # non-empty summary must have been generated.
    assert len(profile.summary) > 0, "应该生成数据摘要"
|
||
|
||
|
||
# Feature: true-ai-agent, Property 2: 数据画像完整性
|
||
@given(df=dataframe_strategy(min_rows=5, max_rows=50))
@settings(max_examples=20, deadline=None)
def test_data_profile_completeness(df):
    """
    Property 2: for any valid CSV file, the generated data profile must carry
    every required field (row count, column count, column info, inferred type,
    key fields, quality score), and each column entry must include its name,
    dtype, missing rate and statistics.

    Verifies requirements: FR-1.2, FR-1.3, FR-1.4.
    """
    profile = understand_data(file_path='test.csv', data=df)

    # Every required profile-level attribute must be present.
    assert hasattr(profile, 'file_path'), "数据画像缺少 file_path 字段"
    assert hasattr(profile, 'row_count'), "数据画像缺少 row_count 字段"
    assert hasattr(profile, 'column_count'), "数据画像缺少 column_count 字段"
    assert hasattr(profile, 'columns'), "数据画像缺少 columns 字段"
    assert hasattr(profile, 'inferred_type'), "数据画像缺少 inferred_type 字段"
    assert hasattr(profile, 'key_fields'), "数据画像缺少 key_fields 字段"
    assert hasattr(profile, 'quality_score'), "数据画像缺少 quality_score 字段"
    assert hasattr(profile, 'summary'), "数据画像缺少 summary 字段"

    # Row/column counts must mirror the source frame exactly.
    assert profile.row_count == len(df), f"行数不匹配:期望 {len(df)},得到 {profile.row_count}"
    assert profile.column_count == len(df.columns), (
        f"列数不匹配:期望 {len(df.columns)},得到 {profile.column_count}"
    )

    # One ColumnInfo entry per source column.
    assert len(profile.columns) == len(df.columns), (
        f"列信息数量不匹配:期望 {len(df.columns)},得到 {len(profile.columns)}"
    )

    for col_info in profile.columns:
        # Each entry must expose the full per-column contract.
        assert hasattr(col_info, 'name'), "列信息缺少 name 字段"
        assert hasattr(col_info, 'dtype'), "列信息缺少 dtype 字段"
        assert hasattr(col_info, 'missing_rate'), "列信息缺少 missing_rate 字段"
        assert hasattr(col_info, 'unique_count'), "列信息缺少 unique_count 字段"
        assert hasattr(col_info, 'statistics'), "列信息缺少 statistics 字段"

        # dtype must come from the closed vocabulary.
        assert col_info.dtype in ['numeric', 'categorical', 'datetime', 'text'], (
            f"列 {col_info.name} 的数据类型应该是预定义的类型之一,但得到:{col_info.dtype}"
        )

        # Missing rate is a fraction in [0, 1].
        assert 0.0 <= col_info.missing_rate <= 1.0, (
            f"列 {col_info.name} 的缺失率应该在 0-1 之间,但得到:{col_info.missing_rate}"
        )

        # Unique counts are bounded by [0, row count].
        assert col_info.unique_count >= 0, (
            f"列 {col_info.name} 的唯一值数量应该非负,但得到:{col_info.unique_count}"
        )
        assert col_info.unique_count <= len(df), (
            f"列 {col_info.name} 的唯一值数量不应超过总行数"
        )

    # Overall quality score is expressed on a 0-100 scale.
    assert 0.0 <= profile.quality_score <= 100.0, (
        f"质量分数应该在 0-100 之间,但得到:{profile.quality_score}"
    )
|
||
|
||
|
||
# 额外测试:验证列类型推断的正确性
|
||
@given(
    numeric_data=st.lists(st.floats(min_value=-1000, max_value=1000, allow_nan=False, allow_infinity=False),
                          min_size=10, max_size=100),
    categorical_data=st.lists(st.sampled_from(['A', 'B', 'C', 'D']), min_size=10, max_size=100)
)
@settings(max_examples=10)
def test_column_type_inference(numeric_data, categorical_data):
    """A numeric series infers as 'numeric'; a small-label series as 'categorical'."""
    # Numeric column.
    numeric_type = _infer_column_type(pd.Series(numeric_data))
    assert numeric_type == 'numeric', f"数值列应该被识别为 'numeric',但得到:{numeric_type}"

    # Categorical column (few distinct labels).
    categorical_type = _infer_column_type(pd.Series(categorical_data))
    assert categorical_type == 'categorical', (
        f"分类列应该被识别为 'categorical',但得到:{categorical_type}"
    )
|
||
|
||
|
||
# 额外测试:验证数据质量评估的合理性
|
||
@given(
    missing_rate=st.floats(min_value=0.0, max_value=1.0),
    n_cols=st.integers(min_value=1, max_value=10)
)
@settings(max_examples=10)
def test_data_quality_evaluation(missing_rate, n_cols):
    """Quality score stays within [0, 100] and drops when data is mostly missing."""
    # Build n_cols identical numeric columns that all share the drawn missing rate.
    columns = [
        ColumnInfo(
            name=f'col_{i}',
            dtype='numeric',
            missing_rate=missing_rate,
            unique_count=100,
            sample_values=[1, 2, 3],
            statistics={},
        )
        for i in range(n_cols)
    ]

    quality_score = _evaluate_data_quality(columns, row_count=100)

    # The score is expressed on a 0-100 scale.
    assert 0.0 <= quality_score <= 100.0, (
        f"质量分数应该在 0-100 之间,但得到:{quality_score}"
    )

    # Heavily missing data must be penalized in the score.
    if missing_rate > 0.5:
        assert quality_score < 70, (
            f"高缺失率({missing_rate})应该导致较低的质量分数,但得到:{quality_score}"
        )
|
||
|
||
|
||
# 额外测试:验证基础统计生成的完整性
|
||
@given(df=dataframe_strategy(min_rows=5, max_rows=50))
@settings(max_examples=10, deadline=None)
def test_basic_stats_generation(df):
    """generate_basic_stats returns all required keys with counts matching the frame."""
    stats = generate_basic_stats(df, file_path='test.csv')

    # All required top-level keys must be present in the stats dict.
    assert 'file_path' in stats, "基础统计缺少 file_path 字段"
    assert 'row_count' in stats, "基础统计缺少 row_count 字段"
    assert 'column_count' in stats, "基础统计缺少 column_count 字段"
    assert 'columns' in stats, "基础统计缺少 columns 字段"

    # Reported counts must agree with the input DataFrame.
    assert stats['row_count'] == len(df), "行数统计不准确"
    assert stats['column_count'] == len(df.columns), "列数统计不准确"
    assert len(stats['columns']) == len(df.columns), "列信息数量不匹配"
|
||
|
||
|
||
# 额外测试:验证关键字段识别
|
||
def test_key_field_identification():
    """_identify_key_fields should surface time, status, id and amount columns."""
    # Small local factory keeps the fixture schema compact.
    def make_col(name, dtype, unique_count):
        return ColumnInfo(name=name, dtype=dtype, missing_rate=0.0,
                          unique_count=unique_count)

    columns = [
        make_col('created_at', 'datetime', 100),
        make_col('status', 'categorical', 5),
        make_col('ticket_id', 'text', 100),
        make_col('amount', 'numeric', 50),
    ]

    key_fields = _identify_key_fields(columns)

    # Each archetypal column should be recognized as a key field.
    assert 'created_at' in key_fields, "应该识别出 created_at 为关键字段"
    assert 'status' in key_fields, "应该识别出 status 为关键字段"
    assert 'ticket_id' in key_fields, "应该识别出 ticket_id 为关键字段"
    assert 'amount' in key_fields, "应该识别出 amount 为关键字段"
|
||
|
||
|
||
# 额外测试:验证数据类型推断
|
||
def test_data_type_inference_with_keywords():
    """_infer_data_type maps keyword-bearing schemas to their business data type."""
    # Small local factory keeps the fixture schemas compact.
    def make_col(name, dtype, unique_count):
        return ColumnInfo(name=name, dtype=dtype, missing_rate=0.0,
                          unique_count=unique_count)

    # Ticket-like schema.
    ticket_type = _infer_data_type([
        make_col('ticket_id', 'text', 100),
        make_col('status', 'categorical', 5),
        make_col('created_at', 'datetime', 100),
    ])
    assert ticket_type == 'ticket', f"应该识别为工单数据,但得到:{ticket_type}"

    # Sales-like schema.
    sales_type = _infer_data_type([
        make_col('order_id', 'text', 100),
        make_col('product', 'categorical', 10),
        make_col('amount', 'numeric', 50),
        make_col('sales_date', 'datetime', 100),
    ])
    assert sales_type == 'sales', f"应该识别为销售数据,但得到:{sales_type}"

    # User-like schema.
    user_type = _infer_data_type([
        make_col('user_id', 'text', 100),
        make_col('name', 'text', 100),
        make_col('email', 'text', 100),
        make_col('age', 'numeric', 50),
    ])
    assert user_type == 'user', f"应该识别为用户数据,但得到:{user_type}"
|