312 lines
10 KiB
Python
312 lines
10 KiB
Python
"""数据理解引擎的单元测试。"""
|
||
|
||
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import pytest

from src.engines.data_understanding import (
    generate_basic_stats,
    understand_data,
    _infer_column_type,
    _infer_data_type,
    _identify_key_fields,
    _evaluate_data_quality,
    _get_sample_values,
    _generate_column_statistics
)
from src.models import DataProfile, ColumnInfo
|
||
|
||
|
||
class TestGenerateBasicStats:
    """Tests for ``generate_basic_stats``."""

    def test_basic_functionality(self):
        """A small three-column frame reports its path, row and column counts."""
        df = pd.DataFrame({
            'id': [1, 2, 3, 4, 5],
            'name': ['A', 'B', 'C', 'D', 'E'],
            'value': [10.5, 20.3, 30.1, 40.8, 50.2],
        })

        stats = generate_basic_stats(df, 'test.csv')

        assert stats['file_path'] == 'test.csv'
        assert stats['row_count'] == 5
        assert stats['column_count'] == 3
        assert len(stats['columns']) == 3

    def test_empty_dataframe(self):
        """An empty DataFrame yields zero rows, zero columns and no column entries."""
        df = pd.DataFrame()

        stats = generate_basic_stats(df, 'empty.csv')

        assert stats['row_count'] == 0
        assert stats['column_count'] == 0
        assert len(stats['columns']) == 0
|
||
|
||
|
||
class TestInferColumnType:
    """Tests for ``_infer_column_type``."""

    def test_numeric_column(self):
        """An integer series is expected to be classified as 'numeric'."""
        col = pd.Series([1, 2, 3, 4, 5])
        dtype = _infer_column_type(col)
        assert dtype == 'numeric'

    def test_categorical_column(self):
        """A low-cardinality string series is expected to be 'categorical'."""
        # 10 values with 3 distinct ones, i.e. a 30% unique ratio.
        col = pd.Series(['A', 'B', 'A', 'C', 'B', 'A', 'A', 'B', 'C', 'A'])
        dtype = _infer_column_type(col)
        assert dtype == 'categorical'

    def test_datetime_column(self):
        """A series built from ``pd.date_range`` is expected to be 'datetime'."""
        col = pd.Series(pd.date_range('2020-01-01', periods=5))
        dtype = _infer_column_type(col)
        assert dtype == 'datetime'

    def test_text_column(self):
        """A string series where every value is unique is expected to be 'text'."""
        col = pd.Series([f'text_{i}' for i in range(100)])
        dtype = _infer_column_type(col)
        assert dtype == 'text'
|
||
|
||
|
||
class TestInferDataType:
    """Tests for ``_infer_data_type``, which labels a dataset from its columns."""

    def test_ticket_data(self):
        """Columns named ticket_id / status / created_at should map to 'ticket'."""
        columns = [
            ColumnInfo(name='ticket_id', dtype='text', missing_rate=0.0, unique_count=100),
            ColumnInfo(name='status', dtype='categorical', missing_rate=0.0, unique_count=5),
            ColumnInfo(name='created_at', dtype='datetime', missing_rate=0.0, unique_count=100),
        ]

        data_type = _infer_data_type(columns)

        assert data_type == 'ticket'

    def test_sales_data(self):
        """Columns named order_id / product / amount should map to 'sales'."""
        columns = [
            ColumnInfo(name='order_id', dtype='text', missing_rate=0.0, unique_count=100),
            ColumnInfo(name='product', dtype='categorical', missing_rate=0.0, unique_count=10),
            ColumnInfo(name='amount', dtype='numeric', missing_rate=0.0, unique_count=50),
        ]

        data_type = _infer_data_type(columns)

        assert data_type == 'sales'

    def test_user_data(self):
        """Columns named user_id / name / email should map to 'user'."""
        columns = [
            ColumnInfo(name='user_id', dtype='text', missing_rate=0.0, unique_count=100),
            ColumnInfo(name='name', dtype='text', missing_rate=0.0, unique_count=100),
            ColumnInfo(name='email', dtype='text', missing_rate=0.0, unique_count=100),
        ]

        data_type = _infer_data_type(columns)

        assert data_type == 'user'

    def test_unknown_data(self):
        """Generic column names (col1, col2) should map to 'unknown'."""
        columns = [
            ColumnInfo(name='col1', dtype='numeric', missing_rate=0.0, unique_count=100),
            ColumnInfo(name='col2', dtype='numeric', missing_rate=0.0, unique_count=100),
        ]

        data_type = _infer_data_type(columns)

        assert data_type == 'unknown'
|
||
|
||
|
||
class TestIdentifyKeyFields:
    """Tests for ``_identify_key_fields``.

    The result is indexed by column name; each entry is checked for a
    Chinese label substring that the engine is expected to emit.
    """

    def test_time_fields(self):
        """created_at / closed_at are reported with creation / completion labels."""
        columns = [
            ColumnInfo(name='created_at', dtype='datetime', missing_rate=0.0, unique_count=100),
            ColumnInfo(name='closed_at', dtype='datetime', missing_rate=0.0, unique_count=100),
        ]

        key_fields = _identify_key_fields(columns)

        assert 'created_at' in key_fields
        assert 'closed_at' in key_fields
        assert '创建时间' in key_fields['created_at']
        assert '完成时间' in key_fields['closed_at']

    def test_status_field(self):
        """A column named 'status' is reported with the status label."""
        columns = [
            ColumnInfo(name='status', dtype='categorical', missing_rate=0.0, unique_count=5),
        ]

        key_fields = _identify_key_fields(columns)

        assert 'status' in key_fields
        assert '状态' in key_fields['status']

    def test_id_field(self):
        """A column named 'ticket_id' is reported with the identifier label."""
        columns = [
            ColumnInfo(name='ticket_id', dtype='text', missing_rate=0.0, unique_count=100),
        ]

        key_fields = _identify_key_fields(columns)

        assert 'ticket_id' in key_fields
        assert '标识符' in key_fields['ticket_id']
|
||
|
||
|
||
class TestEvaluateDataQuality:
    """Tests for ``_evaluate_data_quality``."""

    def test_high_quality_data(self):
        """Columns with no missing values should score at least 80."""
        columns = [
            ColumnInfo(name='col1', dtype='numeric', missing_rate=0.0, unique_count=100),
            ColumnInfo(name='col2', dtype='categorical', missing_rate=0.0, unique_count=5),
        ]

        quality_score = _evaluate_data_quality(columns, row_count=100)

        assert quality_score >= 80

    def test_low_quality_data(self):
        """Columns with 80-90% missing values should score below 50."""
        columns = [
            ColumnInfo(name='col1', dtype='numeric', missing_rate=0.8, unique_count=20),
            ColumnInfo(name='col2', dtype='categorical', missing_rate=0.9, unique_count=2),
        ]

        quality_score = _evaluate_data_quality(columns, row_count=100)

        assert quality_score < 50

    def test_empty_data(self):
        """No columns and zero rows should score exactly 0.0."""
        columns = []

        quality_score = _evaluate_data_quality(columns, row_count=0)

        assert quality_score == 0.0
|
||
|
||
|
||
class TestGetSampleValues:
    """Tests for ``_get_sample_values``."""

    def test_basic_functionality(self):
        """At most ``max_samples`` values come back, each an int or float."""
        col = pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

        samples = _get_sample_values(col, max_samples=5)

        assert len(samples) <= 5
        assert all(isinstance(s, (int, float)) for s in samples)

    def test_with_null_values(self):
        """Nulls are not counted: only four non-null values exist in the input."""
        col = pd.Series([1, 2, None, 4, None, 6])

        samples = _get_sample_values(col, max_samples=5)

        assert len(samples) <= 4  # null values are excluded

    def test_datetime_values(self):
        """Datetime samples are expected to be returned as strings."""
        col = pd.Series(pd.date_range('2020-01-01', periods=5))

        samples = _get_sample_values(col, max_samples=3)

        assert len(samples) <= 3
        assert all(isinstance(s, str) for s in samples)
|
||
|
||
|
||
class TestGenerateColumnStatistics:
    """Tests for ``_generate_column_statistics`` across the four column kinds."""

    def test_numeric_statistics(self):
        """Numeric columns expose mean/median/std/min/max with the expected values."""
        col = pd.Series([1, 2, 3, 4, 5])

        stats = _generate_column_statistics(col, 'numeric')

        assert 'mean' in stats
        assert 'median' in stats
        assert 'std' in stats
        assert 'min' in stats
        assert 'max' in stats
        assert stats['mean'] == 3.0
        assert stats['min'] == 1.0
        assert stats['max'] == 5.0

    def test_categorical_statistics(self):
        """Categorical columns expose the most common value and its count."""
        col = pd.Series(['A', 'B', 'A', 'C', 'A'])

        stats = _generate_column_statistics(col, 'categorical')

        assert 'most_common' in stats
        assert 'most_common_count' in stats
        assert stats['most_common'] == 'A'
        assert stats['most_common_count'] == 3

    def test_datetime_statistics(self):
        """Datetime columns expose min_date, max_date and date_range_days keys."""
        col = pd.Series(pd.date_range('2020-01-01', periods=10))

        stats = _generate_column_statistics(col, 'datetime')

        assert 'min_date' in stats
        assert 'max_date' in stats
        assert 'date_range_days' in stats

    def test_text_statistics(self):
        """Text columns expose avg_length and max_length keys."""
        col = pd.Series(['hello', 'world', 'test'])

        stats = _generate_column_statistics(col, 'text')

        assert 'avg_length' in stats
        assert 'max_length' in stats
|
||
|
||
|
||
class TestUnderstandData:
    """End-to-end tests for ``understand_data``."""

    def test_basic_functionality(self):
        """A small in-memory frame produces a populated ``DataProfile``."""
        df = pd.DataFrame({
            'ticket_id': [1, 2, 3, 4, 5],
            'status': ['open', 'closed', 'open', 'pending', 'closed'],
            'created_at': pd.date_range('2020-01-01', periods=5),
            'amount': [100, 200, 150, 300, 250],
        })

        profile = understand_data('test.csv', data=df)

        assert isinstance(profile, DataProfile)
        assert profile.row_count == 5
        assert profile.column_count == 4
        assert len(profile.columns) == 4
        assert profile.inferred_type in ['ticket', 'sales', 'user', 'unknown']
        assert 0 <= profile.quality_score <= 100
        assert len(profile.summary) > 0

    def test_with_missing_values(self):
        """Missing values keep the row count but pull the quality score below 100."""
        df = pd.DataFrame({
            'col1': [1, 2, None, 4, 5],
            'col2': ['A', None, 'C', 'D', None],
        })

        profile = understand_data('test.csv', data=df)

        assert profile.row_count == 5
        # The quality score should be reduced because of the missing values.
        assert profile.quality_score < 100
|