Complete AI Data Analysis Agent implementation with 95.7% test coverage

This commit is contained in:
2026-03-07 00:04:29 +08:00
parent 621e546b43
commit 7071b1f730
245 changed files with 22612 additions and 2211 deletions

268
tests/test_data_access.py Normal file
View File

@@ -0,0 +1,268 @@
"""数据访问层的单元测试。"""
import pytest
import pandas as pd
import tempfile
import os
from pathlib import Path
from src.data_access import DataAccessLayer, DataLoadError
class TestDataAccessLayer:
"""数据访问层的单元测试。"""
def test_load_utf8_csv(self):
"""测试加载 UTF-8 编码的 CSV 文件。"""
# 创建临时 CSV 文件
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
f.write('id,name,value\n')
f.write('1,测试,100\n')
f.write('2,数据,200\n')
temp_file = f.name
try:
# 加载数据
dal = DataAccessLayer.load_from_file(temp_file)
assert dal.shape == (2, 3)
assert 'id' in dal.columns
assert 'name' in dal.columns
assert 'value' in dal.columns
finally:
os.unlink(temp_file)
def test_load_gbk_csv(self):
"""测试加载 GBK 编码的 CSV 文件。"""
# 创建临时 GBK 编码的 CSV 文件
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='gbk') as f:
f.write('编号,名称,数值\n')
f.write('1,测试,100\n')
f.write('2,数据,200\n')
temp_file = f.name
try:
# 加载数据
dal = DataAccessLayer.load_from_file(temp_file)
assert dal.shape == (2, 3)
assert len(dal.columns) == 3
finally:
os.unlink(temp_file)
def test_load_empty_file(self):
"""测试加载空文件。"""
# 创建空的 CSV 文件
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
f.write('id,name\n') # 只有表头,没有数据
temp_file = f.name
try:
# 应该抛出 DataLoadError
with pytest.raises(DataLoadError, match="为空"):
DataAccessLayer.load_from_file(temp_file)
finally:
os.unlink(temp_file)
def test_load_invalid_file(self):
"""测试加载不存在的文件。"""
with pytest.raises(DataLoadError):
DataAccessLayer.load_from_file('nonexistent_file.csv')
def test_get_profile_basic(self):
"""测试生成基本数据画像。"""
# 创建测试数据
df = pd.DataFrame({
'id': [1, 2, 3, 4, 5],
'name': ['A', 'B', 'C', 'D', 'E'],
'value': [10, 20, 30, 40, 50],
'status': ['open', 'closed', 'open', 'closed', 'open']
})
dal = DataAccessLayer(df, file_path='test.csv')
profile = dal.get_profile()
# 验证基本信息
assert profile.file_path == 'test.csv'
assert profile.row_count == 5
assert profile.column_count == 4
assert len(profile.columns) == 4
# 验证列信息
col_names = [col.name for col in profile.columns]
assert 'id' in col_names
assert 'name' in col_names
assert 'value' in col_names
assert 'status' in col_names
def test_get_profile_with_missing_values(self):
"""测试包含缺失值的数据画像。"""
df = pd.DataFrame({
'id': [1, 2, 3, 4, 5],
'value': [10, None, 30, None, 50]
})
dal = DataAccessLayer(df)
profile = dal.get_profile()
# 查找 value 列
value_col = next(col for col in profile.columns if col.name == 'value')
# 验证缺失率
assert value_col.missing_rate == 0.4 # 2/5 = 0.4
def test_column_type_inference_numeric(self):
"""测试数值类型推断。"""
df = pd.DataFrame({
'int_col': [1, 2, 3, 4, 5],
'float_col': [1.1, 2.2, 3.3, 4.4, 5.5]
})
dal = DataAccessLayer(df)
profile = dal.get_profile()
int_col = next(col for col in profile.columns if col.name == 'int_col')
float_col = next(col for col in profile.columns if col.name == 'float_col')
assert int_col.dtype == 'numeric'
assert float_col.dtype == 'numeric'
# 验证统计信息
assert 'mean' in int_col.statistics
assert 'std' in int_col.statistics
assert 'min' in int_col.statistics
assert 'max' in int_col.statistics
def test_column_type_inference_categorical(self):
"""测试分类类型推断。"""
df = pd.DataFrame({
'status': ['open', 'closed', 'open', 'closed', 'open'] * 20
})
dal = DataAccessLayer(df)
profile = dal.get_profile()
status_col = profile.columns[0]
assert status_col.dtype == 'categorical'
# 验证统计信息
assert 'top_values' in status_col.statistics
assert 'num_categories' in status_col.statistics
def test_column_type_inference_datetime(self):
"""测试日期时间类型推断。"""
df = pd.DataFrame({
'date': pd.date_range('2020-01-01', periods=10)
})
dal = DataAccessLayer(df)
profile = dal.get_profile()
date_col = profile.columns[0]
assert date_col.dtype == 'datetime'
def test_sample_values_limit(self):
"""测试示例值数量限制。"""
df = pd.DataFrame({
'id': list(range(100))
})
dal = DataAccessLayer(df)
profile = dal.get_profile()
id_col = profile.columns[0]
# 示例值应该最多5个
assert len(id_col.sample_values) <= 5
def test_sanitize_result_dataframe(self):
"""测试结果过滤 - DataFrame。"""
df = pd.DataFrame({
'id': list(range(200)),
'value': list(range(200))
})
dal = DataAccessLayer(df)
# 模拟工具返回大量数据
result = {'data': df}
sanitized = dal._sanitize_result(result)
# 验证返回的数据应该被截断到100行
assert len(sanitized['data']) <= 100
def test_sanitize_result_series(self):
"""测试结果过滤 - Series。"""
df = pd.DataFrame({
'id': list(range(200))
})
dal = DataAccessLayer(df)
# 模拟工具返回 Series
result = {'data': df['id']}
sanitized = dal._sanitize_result(result)
# 验证:返回的数据应该被截断
assert len(sanitized['data']) <= 100
def test_large_dataset_sampling(self):
"""测试大数据集采样。"""
# 创建超过100万行的临时文件
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
f.write('id,value\n')
# 写入少量数据用于测试(实际测试大数据集会很慢)
for i in range(1000):
f.write(f'{i},{i*10}\n')
temp_file = f.name
try:
dal = DataAccessLayer.load_from_file(temp_file)
# 验证数据被加载
assert dal.shape[0] == 1000
finally:
os.unlink(temp_file)
class TestDataAccessLayerIntegration:
"""数据访问层的集成测试。"""
def test_end_to_end_workflow(self):
"""测试端到端工作流程。"""
# 创建测试数据
df = pd.DataFrame({
'id': [1, 2, 3, 4, 5],
'status': ['open', 'closed', 'open', 'closed', 'pending'],
'value': [100, 200, 150, 300, 250],
'created_at': pd.date_range('2020-01-01', periods=5)
})
# 保存到临时文件
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
df.to_csv(f.name, index=False)
temp_file = f.name
try:
# 1. 加载数据
dal = DataAccessLayer.load_from_file(temp_file)
# 2. 生成数据画像
profile = dal.get_profile()
# 3. 验证数据画像
assert profile.row_count == 5
assert profile.column_count == 4
# 4. 验证列类型推断
col_types = {col.name: col.dtype for col in profile.columns}
assert col_types['id'] == 'numeric'
assert col_types['status'] == 'categorical'
assert col_types['value'] == 'numeric'
assert col_types['created_at'] == 'datetime'
# 5. 验证统计信息
value_col = next(col for col in profile.columns if col.name == 'value')
assert 'mean' in value_col.statistics
assert value_col.statistics['mean'] == 200.0
finally:
os.unlink(temp_file)