Complete AI Data Analysis Agent implementation with 95.7% test coverage

This commit is contained in:
2026-03-07 00:04:29 +08:00
parent 621e546b43
commit 7071b1f730
245 changed files with 22612 additions and 2211 deletions

View File

@@ -0,0 +1,311 @@
"""数据理解引擎的单元测试。"""
import pytest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from src.engines.data_understanding import (
generate_basic_stats,
understand_data,
_infer_column_type,
_infer_data_type,
_identify_key_fields,
_evaluate_data_quality,
_get_sample_values,
_generate_column_statistics
)
from src.models import DataProfile, ColumnInfo
class TestGenerateBasicStats:
    """Tests for ``generate_basic_stats``."""

    def test_basic_functionality(self):
        """A three-column, five-row frame reports its path and dimensions."""
        frame = pd.DataFrame(
            {
                'id': [1, 2, 3, 4, 5],
                'name': ['A', 'B', 'C', 'D', 'E'],
                'value': [10.5, 20.3, 30.1, 40.8, 50.2],
            }
        )

        result = generate_basic_stats(frame, 'test.csv')

        assert result['file_path'] == 'test.csv'
        assert result['row_count'] == 5
        assert result['column_count'] == 3
        assert len(result['columns']) == 3

    def test_empty_dataframe(self):
        """An empty DataFrame yields zero rows, zero columns, no column entries."""
        result = generate_basic_stats(pd.DataFrame(), 'empty.csv')

        assert result['row_count'] == 0
        assert result['column_count'] == 0
        assert len(result['columns']) == 0
class TestInferColumnType:
    """Tests for ``_infer_column_type``."""

    def test_numeric_column(self):
        """An integer series is classified as numeric."""
        series = pd.Series([1, 2, 3, 4, 5])
        assert _infer_column_type(series) == 'numeric'

    def test_categorical_column(self):
        """A low-cardinality string series is classified as categorical."""
        # 10 values with 3 distinct labels, i.e. a 30% unique ratio.
        labels = ['A', 'B', 'A', 'C', 'B', 'A', 'A', 'B', 'C', 'A']
        series = pd.Series(labels)
        assert _infer_column_type(series) == 'categorical'

    def test_datetime_column(self):
        """A series built from a date range is classified as datetime."""
        series = pd.Series(pd.date_range('2020-01-01', periods=5))
        assert _infer_column_type(series) == 'datetime'

    def test_text_column(self):
        """A string series where every value is unique is classified as text."""
        values = [f'text_{i}' for i in range(100)]
        series = pd.Series(values)
        assert _infer_column_type(series) == 'text'
class TestInferDataType:
    """Tests for ``_infer_data_type``."""

    def test_ticket_data(self):
        """Ticket-like column names are recognised as 'ticket' data."""
        specs = [
            ('ticket_id', 'text', 100),
            ('status', 'categorical', 5),
            ('created_at', 'datetime', 100),
        ]
        cols = [
            ColumnInfo(name=name, dtype=kind, missing_rate=0.0, unique_count=uniq)
            for name, kind, uniq in specs
        ]
        assert _infer_data_type(cols) == 'ticket'

    def test_sales_data(self):
        """Order/product/amount columns are recognised as 'sales' data."""
        specs = [
            ('order_id', 'text', 100),
            ('product', 'categorical', 10),
            ('amount', 'numeric', 50),
        ]
        cols = [
            ColumnInfo(name=name, dtype=kind, missing_rate=0.0, unique_count=uniq)
            for name, kind, uniq in specs
        ]
        assert _infer_data_type(cols) == 'sales'

    def test_user_data(self):
        """User id/name/email columns are recognised as 'user' data."""
        specs = [
            ('user_id', 'text', 100),
            ('name', 'text', 100),
            ('email', 'text', 100),
        ]
        cols = [
            ColumnInfo(name=name, dtype=kind, missing_rate=0.0, unique_count=uniq)
            for name, kind, uniq in specs
        ]
        assert _infer_data_type(cols) == 'user'

    def test_unknown_data(self):
        """Generic column names fall back to the 'unknown' data type."""
        specs = [
            ('col1', 'numeric', 100),
            ('col2', 'numeric', 100),
        ]
        cols = [
            ColumnInfo(name=name, dtype=kind, missing_rate=0.0, unique_count=uniq)
            for name, kind, uniq in specs
        ]
        assert _infer_data_type(cols) == 'unknown'
class TestIdentifyKeyFields:
    """Tests for ``_identify_key_fields``."""

    def test_time_fields(self):
        """Creation and completion timestamps are both reported as key fields."""
        created = ColumnInfo(
            name='created_at', dtype='datetime', missing_rate=0.0, unique_count=100
        )
        closed = ColumnInfo(
            name='closed_at', dtype='datetime', missing_rate=0.0, unique_count=100
        )

        found = _identify_key_fields([created, closed])

        assert 'created_at' in found
        assert 'closed_at' in found
        # The descriptions are emitted in Chinese by the engine.
        assert '创建时间' in found['created_at']
        assert '完成时间' in found['closed_at']

    def test_status_field(self):
        """A status column is reported with a status description."""
        status = ColumnInfo(
            name='status', dtype='categorical', missing_rate=0.0, unique_count=5
        )

        found = _identify_key_fields([status])

        assert 'status' in found
        assert '状态' in found['status']

    def test_id_field(self):
        """An ID column is reported with an identifier description."""
        ticket_id = ColumnInfo(
            name='ticket_id', dtype='text', missing_rate=0.0, unique_count=100
        )

        found = _identify_key_fields([ticket_id])

        assert 'ticket_id' in found
        assert '标识符' in found['ticket_id']
class TestEvaluateDataQuality:
    """Tests for ``_evaluate_data_quality``."""

    def test_high_quality_data(self):
        """Columns without missing values score at least 80."""
        complete_cols = [
            ColumnInfo(name='col1', dtype='numeric', missing_rate=0.0, unique_count=100),
            ColumnInfo(name='col2', dtype='categorical', missing_rate=0.0, unique_count=5),
        ]
        score = _evaluate_data_quality(complete_cols, row_count=100)
        assert score >= 80

    def test_low_quality_data(self):
        """Columns with very high missing rates score below 50."""
        sparse_cols = [
            ColumnInfo(name='col1', dtype='numeric', missing_rate=0.8, unique_count=20),
            ColumnInfo(name='col2', dtype='categorical', missing_rate=0.9, unique_count=2),
        ]
        score = _evaluate_data_quality(sparse_cols, row_count=100)
        assert score < 50

    def test_empty_data(self):
        """No columns and no rows give a score of exactly 0.0."""
        score = _evaluate_data_quality([], row_count=0)
        assert score == 0.0
class TestGetSampleValues:
    """Tests for ``_get_sample_values``."""

    def test_basic_functionality(self):
        """Numeric samples respect the cap and stay int/float."""
        series = pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

        picked = _get_sample_values(series, max_samples=5)

        assert len(picked) <= 5
        assert all(isinstance(item, (int, float)) for item in picked)

    def test_with_null_values(self):
        """Null entries are not counted among the returned samples."""
        series = pd.Series([1, 2, None, 4, None, 6])

        picked = _get_sample_values(series, max_samples=5)

        # Only four non-null values exist, so at most four can come back.
        assert len(picked) <= 4

    def test_datetime_values(self):
        """Datetime samples respect the cap and are returned as strings."""
        series = pd.Series(pd.date_range('2020-01-01', periods=5))

        picked = _get_sample_values(series, max_samples=3)

        assert len(picked) <= 3
        assert all(isinstance(item, str) for item in picked)
class TestGenerateColumnStatistics:
    """Tests for ``_generate_column_statistics``."""

    def test_numeric_statistics(self):
        """Numeric columns expose mean/median/std/min/max with correct values."""
        series = pd.Series([1, 2, 3, 4, 5])

        summary = _generate_column_statistics(series, 'numeric')

        for key in ('mean', 'median', 'std', 'min', 'max'):
            assert key in summary
        assert summary['mean'] == 3.0
        assert summary['min'] == 1.0
        assert summary['max'] == 5.0

    def test_categorical_statistics(self):
        """Categorical columns report the most frequent value and its count."""
        series = pd.Series(['A', 'B', 'A', 'C', 'A'])

        summary = _generate_column_statistics(series, 'categorical')

        for key in ('most_common', 'most_common_count'):
            assert key in summary
        assert summary['most_common'] == 'A'
        assert summary['most_common_count'] == 3

    def test_datetime_statistics(self):
        """Datetime columns expose the date bounds and the span in days."""
        series = pd.Series(pd.date_range('2020-01-01', periods=10))

        summary = _generate_column_statistics(series, 'datetime')

        for key in ('min_date', 'max_date', 'date_range_days'):
            assert key in summary

    def test_text_statistics(self):
        """Text columns expose average and maximum string length."""
        series = pd.Series(['hello', 'world', 'test'])

        summary = _generate_column_statistics(series, 'text')

        for key in ('avg_length', 'max_length'):
            assert key in summary
class TestUnderstandData:
    """End-to-end tests for ``understand_data``."""

    def test_basic_functionality(self):
        """A small ticket-like frame produces a well-formed DataProfile."""
        frame = pd.DataFrame(
            {
                'ticket_id': [1, 2, 3, 4, 5],
                'status': ['open', 'closed', 'open', 'pending', 'closed'],
                'created_at': pd.date_range('2020-01-01', periods=5),
                'amount': [100, 200, 150, 300, 250],
            }
        )

        result = understand_data('test.csv', data=frame)

        assert isinstance(result, DataProfile)
        assert result.row_count == 5
        assert result.column_count == 4
        assert len(result.columns) == 4
        known_types = ('ticket', 'sales', 'user', 'unknown')
        assert result.inferred_type in known_types
        assert 0 <= result.quality_score <= 100
        assert len(result.summary) > 0

    def test_with_missing_values(self):
        """Missing values keep the row count but pull the score below 100."""
        frame = pd.DataFrame(
            {
                'col1': [1, 2, None, 4, 5],
                'col2': ['A', None, 'C', 'D', None],
            }
        )

        result = understand_data('test.csv', data=frame)

        assert result.row_count == 5
        # The quality score should be reduced because of the missing values.
        assert result.quality_score < 100