"""数据理解引擎的单元测试。""" import pytest import pandas as pd import numpy as np from datetime import datetime, timedelta from src.engines.data_understanding import ( generate_basic_stats, understand_data, _infer_column_type, _infer_data_type, _identify_key_fields, _evaluate_data_quality, _get_sample_values, _generate_column_statistics ) from src.models import DataProfile, ColumnInfo class TestGenerateBasicStats: """测试基础统计生成。""" def test_basic_functionality(self): """测试基本功能。""" df = pd.DataFrame({ 'id': [1, 2, 3, 4, 5], 'name': ['A', 'B', 'C', 'D', 'E'], 'value': [10.5, 20.3, 30.1, 40.8, 50.2] }) stats = generate_basic_stats(df, 'test.csv') assert stats['file_path'] == 'test.csv' assert stats['row_count'] == 5 assert stats['column_count'] == 3 assert len(stats['columns']) == 3 def test_empty_dataframe(self): """测试空 DataFrame。""" df = pd.DataFrame() stats = generate_basic_stats(df, 'empty.csv') assert stats['row_count'] == 0 assert stats['column_count'] == 0 assert len(stats['columns']) == 0 class TestInferColumnType: """测试列类型推断。""" def test_numeric_column(self): """测试数值列。""" col = pd.Series([1, 2, 3, 4, 5]) dtype = _infer_column_type(col) assert dtype == 'numeric' def test_categorical_column(self): """测试分类列。""" col = pd.Series(['A', 'B', 'A', 'C', 'B', 'A', 'A', 'B', 'C', 'A']) # 10个值,3个唯一值,比例30% dtype = _infer_column_type(col) assert dtype == 'categorical' def test_datetime_column(self): """测试日期时间列。""" col = pd.Series(pd.date_range('2020-01-01', periods=5)) dtype = _infer_column_type(col) assert dtype == 'datetime' def test_text_column(self): """测试文本列(唯一值多)。""" col = pd.Series([f'text_{i}' for i in range(100)]) dtype = _infer_column_type(col) assert dtype == 'text' class TestInferDataType: """测试数据类型推断。""" def test_ticket_data(self): """测试工单数据识别。""" columns = [ ColumnInfo(name='ticket_id', dtype='text', missing_rate=0.0, unique_count=100), ColumnInfo(name='status', dtype='categorical', missing_rate=0.0, unique_count=5), ColumnInfo(name='created_at', dtype='datetime', missing_rate=0.0, unique_count=100), ] data_type = _infer_data_type(columns) assert data_type == 'ticket' def test_sales_data(self): """测试销售数据识别。""" columns = [ ColumnInfo(name='order_id', dtype='text', missing_rate=0.0, unique_count=100), ColumnInfo(name='product', dtype='categorical', missing_rate=0.0, unique_count=10), ColumnInfo(name='amount', dtype='numeric', missing_rate=0.0, unique_count=50), ] data_type = _infer_data_type(columns) assert data_type == 'sales' def test_user_data(self): """测试用户数据识别。""" columns = [ ColumnInfo(name='user_id', dtype='text', missing_rate=0.0, unique_count=100), ColumnInfo(name='name', dtype='text', missing_rate=0.0, unique_count=100), ColumnInfo(name='email', dtype='text', missing_rate=0.0, unique_count=100), ] data_type = _infer_data_type(columns) assert data_type == 'user' def test_unknown_data(self): """测试未知数据类型。""" columns = [ ColumnInfo(name='col1', dtype='numeric', missing_rate=0.0, unique_count=100), ColumnInfo(name='col2', dtype='numeric', missing_rate=0.0, unique_count=100), ] data_type = _infer_data_type(columns) assert data_type == 'unknown' class TestIdentifyKeyFields: """测试关键字段识别。""" def test_time_fields(self): """测试时间字段识别。""" columns = [ ColumnInfo(name='created_at', dtype='datetime', missing_rate=0.0, unique_count=100), ColumnInfo(name='closed_at', dtype='datetime', missing_rate=0.0, unique_count=100), ] key_fields = _identify_key_fields(columns) assert 'created_at' in key_fields assert 'closed_at' in key_fields assert '创建时间' in key_fields['created_at'] assert '完成时间' in key_fields['closed_at'] def test_status_field(self): """测试状态字段识别。""" columns = [ ColumnInfo(name='status', dtype='categorical', missing_rate=0.0, unique_count=5), ] key_fields = _identify_key_fields(columns) assert 'status' in key_fields assert '状态' in key_fields['status'] def test_id_field(self): """测试ID字段识别。""" columns = [ ColumnInfo(name='ticket_id', dtype='text', missing_rate=0.0, unique_count=100), ] key_fields = _identify_key_fields(columns) assert 'ticket_id' in key_fields assert '标识符' in key_fields['ticket_id'] class TestEvaluateDataQuality: """测试数据质量评估。""" def test_high_quality_data(self): """测试高质量数据。""" columns = [ ColumnInfo(name='col1', dtype='numeric', missing_rate=0.0, unique_count=100), ColumnInfo(name='col2', dtype='categorical', missing_rate=0.0, unique_count=5), ] quality_score = _evaluate_data_quality(columns, row_count=100) assert quality_score >= 80 def test_low_quality_data(self): """测试低质量数据(高缺失率)。""" columns = [ ColumnInfo(name='col1', dtype='numeric', missing_rate=0.8, unique_count=20), ColumnInfo(name='col2', dtype='categorical', missing_rate=0.9, unique_count=2), ] quality_score = _evaluate_data_quality(columns, row_count=100) assert quality_score < 50 def test_empty_data(self): """测试空数据。""" columns = [] quality_score = _evaluate_data_quality(columns, row_count=0) assert quality_score == 0.0 class TestGetSampleValues: """测试示例值获取。""" def test_basic_functionality(self): """测试基本功能。""" col = pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) samples = _get_sample_values(col, max_samples=5) assert len(samples) <= 5 assert all(isinstance(s, (int, float)) for s in samples) def test_with_null_values(self): """测试包含空值的情况。""" col = pd.Series([1, 2, None, 4, None, 6]) samples = _get_sample_values(col, max_samples=5) assert len(samples) <= 4 # 排除了空值 def test_datetime_values(self): """测试日期时间值。""" col = pd.Series(pd.date_range('2020-01-01', periods=5)) samples = _get_sample_values(col, max_samples=3) assert len(samples) <= 3 assert all(isinstance(s, str) for s in samples) class TestGenerateColumnStatistics: """测试列统计信息生成。""" def test_numeric_statistics(self): """测试数值列统计。""" col = pd.Series([1, 2, 3, 4, 5]) stats = _generate_column_statistics(col, 'numeric') assert 'mean' in stats assert 'median' in stats assert 'std' in stats assert 'min' in stats assert 'max' in stats assert stats['mean'] == 3.0 assert stats['min'] == 1.0 assert stats['max'] == 5.0 def test_categorical_statistics(self): """测试分类列统计。""" col = pd.Series(['A', 'B', 'A', 'C', 'A']) stats = _generate_column_statistics(col, 'categorical') assert 'most_common' in stats assert 'most_common_count' in stats assert stats['most_common'] == 'A' assert stats['most_common_count'] == 3 def test_datetime_statistics(self): """测试日期时间列统计。""" col = pd.Series(pd.date_range('2020-01-01', periods=10)) stats = _generate_column_statistics(col, 'datetime') assert 'min_date' in stats assert 'max_date' in stats assert 'date_range_days' in stats def test_text_statistics(self): """测试文本列统计。""" col = pd.Series(['hello', 'world', 'test']) stats = _generate_column_statistics(col, 'text') assert 'avg_length' in stats assert 'max_length' in stats class TestUnderstandData: """测试完整的数据理解流程。""" def test_basic_functionality(self): """测试基本功能。""" df = pd.DataFrame({ 'ticket_id': [1, 2, 3, 4, 5], 'status': ['open', 'closed', 'open', 'pending', 'closed'], 'created_at': pd.date_range('2020-01-01', periods=5), 'amount': [100, 200, 150, 300, 250] }) profile = understand_data('test.csv', data=df) assert isinstance(profile, DataProfile) assert profile.row_count == 5 assert profile.column_count == 4 assert len(profile.columns) == 4 assert profile.inferred_type in ['ticket', 'sales', 'user', 'unknown'] assert 0 <= profile.quality_score <= 100 assert len(profile.summary) > 0 def test_with_missing_values(self): """测试包含缺失值的数据。""" df = pd.DataFrame({ 'col1': [1, 2, None, 4, 5], 'col2': ['A', None, 'C', 'D', None] }) profile = understand_data('test.csv', data=df) assert profile.row_count == 5 # 质量分数应该因为缺失值而降低 assert profile.quality_score < 100