# Source: vibe_data_ana/tests/test_data_understanding.py (312 lines, 10 KiB, Python)
#
# NOTE: The repository viewer flagged this file as containing ambiguous Unicode
# characters, i.e. characters that might be confused with other characters.
# If this is intentional, the warning can be safely ignored.
"""数据理解引擎的单元测试。"""
import pytest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from src.engines.data_understanding import (
generate_basic_stats,
understand_data,
_infer_column_type,
_infer_data_type,
_identify_key_fields,
_evaluate_data_quality,
_get_sample_values,
_generate_column_statistics
)
from src.models import DataProfile, ColumnInfo
class TestGenerateBasicStats:
    """Tests for ``generate_basic_stats``."""

    def test_basic_functionality(self):
        """A three-column, five-row frame reports its path, row count and column count."""
        frame = pd.DataFrame(
            {
                'id': [1, 2, 3, 4, 5],
                'name': ['A', 'B', 'C', 'D', 'E'],
                'value': [10.5, 20.3, 30.1, 40.8, 50.2],
            }
        )

        result = generate_basic_stats(frame, 'test.csv')

        expected = {'file_path': 'test.csv', 'row_count': 5, 'column_count': 3}
        for key, value in expected.items():
            assert result[key] == value
        assert len(result['columns']) == 3

    def test_empty_dataframe(self):
        """An empty DataFrame reports zero rows, zero columns and no column entries."""
        result = generate_basic_stats(pd.DataFrame(), 'empty.csv')

        for key in ('row_count', 'column_count'):
            assert result[key] == 0
        assert len(result['columns']) == 0
class TestInferColumnType:
    """Tests for ``_infer_column_type``."""

    def test_numeric_column(self):
        """An integer series is expected to be classified as 'numeric'."""
        series = pd.Series([1, 2, 3, 4, 5])
        assert _infer_column_type(series) == 'numeric'

    def test_categorical_column(self):
        """A string series with few distinct values is expected to be 'categorical'."""
        # 10 values with 3 distinct ones, i.e. a 30% unique ratio.
        labels = ['A', 'B', 'A', 'C', 'B', 'A', 'A', 'B', 'C', 'A']
        assert _infer_column_type(pd.Series(labels)) == 'categorical'

    def test_datetime_column(self):
        """A series built from a date range is expected to be 'datetime'."""
        days = pd.date_range('2020-01-01', periods=5)
        assert _infer_column_type(pd.Series(days)) == 'datetime'

    def test_text_column(self):
        """A string series where every value is distinct is expected to be 'text'."""
        values = [f'text_{i}' for i in range(100)]
        assert _infer_column_type(pd.Series(values)) == 'text'
class TestInferDataType:
    """Tests for ``_infer_data_type``."""

    @staticmethod
    def _columns(*specs):
        """Build ColumnInfo objects from (name, dtype, unique_count) triples.

        Every column is created with ``missing_rate=0.0``.
        """
        return [
            ColumnInfo(name=name, dtype=dtype, missing_rate=0.0, unique_count=unique_count)
            for name, dtype, unique_count in specs
        ]

    def test_ticket_data(self):
        """Ticket-style column names are expected to be recognised as 'ticket'."""
        columns = self._columns(
            ('ticket_id', 'text', 100),
            ('status', 'categorical', 5),
            ('created_at', 'datetime', 100),
        )
        assert _infer_data_type(columns) == 'ticket'

    def test_sales_data(self):
        """Order/product/amount column names are expected to be recognised as 'sales'."""
        columns = self._columns(
            ('order_id', 'text', 100),
            ('product', 'categorical', 10),
            ('amount', 'numeric', 50),
        )
        assert _infer_data_type(columns) == 'sales'

    def test_user_data(self):
        """User/name/email column names are expected to be recognised as 'user'."""
        columns = self._columns(
            ('user_id', 'text', 100),
            ('name', 'text', 100),
            ('email', 'text', 100),
        )
        assert _infer_data_type(columns) == 'user'

    def test_unknown_data(self):
        """Generic column names are expected to fall back to 'unknown'."""
        columns = self._columns(
            ('col1', 'numeric', 100),
            ('col2', 'numeric', 100),
        )
        assert _infer_data_type(columns) == 'unknown'
class TestIdentifyKeyFields:
    """Tests for ``_identify_key_fields``."""

    def test_time_fields(self):
        """Both datetime columns are reported, each with its expected label text."""
        columns = [
            ColumnInfo(name=field, dtype='datetime', missing_rate=0.0, unique_count=100)
            for field in ('created_at', 'closed_at')
        ]

        key_fields = _identify_key_fields(columns)

        expected_labels = {'created_at': '创建时间', 'closed_at': '完成时间'}
        # Check presence of every field first, then the label text of each one.
        for field in expected_labels:
            assert field in key_fields
        for field, label in expected_labels.items():
            assert label in key_fields[field]

    def test_status_field(self):
        """A categorical 'status' column is reported with the status label text."""
        status_column = ColumnInfo(
            name='status', dtype='categorical', missing_rate=0.0, unique_count=5
        )

        key_fields = _identify_key_fields([status_column])

        assert 'status' in key_fields
        assert '状态' in key_fields['status']

    def test_id_field(self):
        """A 'ticket_id' column is reported with the identifier label text."""
        id_column = ColumnInfo(
            name='ticket_id', dtype='text', missing_rate=0.0, unique_count=100
        )

        key_fields = _identify_key_fields([id_column])

        assert 'ticket_id' in key_fields
        assert '标识符' in key_fields['ticket_id']
class TestEvaluateDataQuality:
    """Tests for ``_evaluate_data_quality``."""

    def test_high_quality_data(self):
        """Columns without missing values are expected to score at least 80."""
        specs = [
            ('col1', 'numeric', 0.0, 100),
            ('col2', 'categorical', 0.0, 5),
        ]
        columns = [
            ColumnInfo(name=name, dtype=dtype, missing_rate=rate, unique_count=uniques)
            for name, dtype, rate, uniques in specs
        ]

        score = _evaluate_data_quality(columns, row_count=100)

        assert score >= 80

    def test_low_quality_data(self):
        """Columns with 80-90% missing values are expected to score below 50."""
        specs = [
            ('col1', 'numeric', 0.8, 20),
            ('col2', 'categorical', 0.9, 2),
        ]
        columns = [
            ColumnInfo(name=name, dtype=dtype, missing_rate=rate, unique_count=uniques)
            for name, dtype, rate, uniques in specs
        ]

        score = _evaluate_data_quality(columns, row_count=100)

        assert score < 50

    def test_empty_data(self):
        """No columns and zero rows are expected to score exactly 0.0."""
        assert _evaluate_data_quality([], row_count=0) == 0.0
class TestGetSampleValues:
    """Tests for ``_get_sample_values``."""

    def test_basic_functionality(self):
        """At most ``max_samples`` values are returned, each a plain number."""
        col = pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        samples = _get_sample_values(col, max_samples=5)
        assert len(samples) <= 5
        assert all(isinstance(s, (int, float)) for s in samples)

    def test_with_null_values(self):
        """Null entries must not appear among the returned samples."""
        col = pd.Series([1, 2, None, 4, None, 6])
        samples = _get_sample_values(col, max_samples=5)
        # Only four of the six entries are non-null, so at most four samples
        # can come back once nulls are excluded.
        assert len(samples) <= 4
        # The size bound alone would still pass if a null leaked into a short
        # sample list, so check the exclusion explicitly.
        assert not any(pd.isna(s) for s in samples)

    def test_datetime_values(self):
        """Datetime values are returned as strings, limited to ``max_samples``."""
        col = pd.Series(pd.date_range('2020-01-01', periods=5))
        samples = _get_sample_values(col, max_samples=3)
        assert len(samples) <= 3
        assert all(isinstance(s, str) for s in samples)
class TestGenerateColumnStatistics:
    """Tests for ``_generate_column_statistics``."""

    def test_numeric_statistics(self):
        """Numeric columns expose mean/median/std/min/max with the expected values."""
        series = pd.Series([1, 2, 3, 4, 5])

        result = _generate_column_statistics(series, 'numeric')

        for key in ('mean', 'median', 'std', 'min', 'max'):
            assert key in result
        expected_values = {'mean': 3.0, 'min': 1.0, 'max': 5.0}
        for key, value in expected_values.items():
            assert result[key] == value

    def test_categorical_statistics(self):
        """Categorical columns expose the most common value and its count."""
        series = pd.Series(['A', 'B', 'A', 'C', 'A'])

        result = _generate_column_statistics(series, 'categorical')

        for key in ('most_common', 'most_common_count'):
            assert key in result
        assert result['most_common'] == 'A'
        assert result['most_common_count'] == 3

    def test_datetime_statistics(self):
        """Datetime columns expose min/max date and the day span keys."""
        series = pd.Series(pd.date_range('2020-01-01', periods=10))

        result = _generate_column_statistics(series, 'datetime')

        for key in ('min_date', 'max_date', 'date_range_days'):
            assert key in result

    def test_text_statistics(self):
        """Text columns expose average and maximum length keys."""
        series = pd.Series(['hello', 'world', 'test'])

        result = _generate_column_statistics(series, 'text')

        for key in ('avg_length', 'max_length'):
            assert key in result
class TestUnderstandData:
    """End-to-end tests for ``understand_data``."""

    def test_basic_functionality(self):
        """A small mixed-type frame produces a populated ``DataProfile``."""
        frame = pd.DataFrame(
            {
                'ticket_id': [1, 2, 3, 4, 5],
                'status': ['open', 'closed', 'open', 'pending', 'closed'],
                'created_at': pd.date_range('2020-01-01', periods=5),
                'amount': [100, 200, 150, 300, 250],
            }
        )

        profile = understand_data('test.csv', data=frame)

        assert isinstance(profile, DataProfile)
        assert profile.row_count == 5
        assert profile.column_count == 4
        assert len(profile.columns) == 4
        known_types = ('ticket', 'sales', 'user', 'unknown')
        assert profile.inferred_type in known_types
        assert 0 <= profile.quality_score <= 100
        assert len(profile.summary) > 0

    def test_with_missing_values(self):
        """Missing values keep the row count but pull the quality score below 100."""
        frame = pd.DataFrame(
            {
                'col1': [1, 2, None, 4, 5],
                'col2': ['A', None, 'C', 'D', None],
            }
        )

        profile = understand_data('test.csv', data=frame)

        assert profile.row_count == 5
        # The quality score should be reduced because of the missing values.
        assert profile.quality_score < 100