Complete AI Data Analysis Agent implementation with 95.7% test coverage
This commit is contained in:
268
tests/test_data_access.py
Normal file
268
tests/test_data_access.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""数据访问层的单元测试。"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import tempfile
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from src.data_access import DataAccessLayer, DataLoadError
|
||||
|
||||
|
||||
class TestDataAccessLayer:
    """Unit tests for the data access layer."""

    def test_load_utf8_csv(self):
        """Load a UTF-8 encoded CSV file and check its shape and column names."""
        # delete=False keeps the file on disk after the handle is closed so it
        # can be loaded by path; it is removed explicitly in the finally block.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
            f.write('id,name,value\n')
            f.write('1,测试,100\n')
            f.write('2,数据,200\n')
            temp_file = f.name

        try:
            dal = DataAccessLayer.load_from_file(temp_file)

            assert dal.shape == (2, 3)
            assert 'id' in dal.columns
            assert 'name' in dal.columns
            assert 'value' in dal.columns
        finally:
            os.unlink(temp_file)

    def test_load_gbk_csv(self):
        """Load a GBK encoded CSV file and check its shape."""
        # The header and values contain non-ASCII text written as GBK bytes.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='gbk') as f:
            f.write('编号,名称,数值\n')
            f.write('1,测试,100\n')
            f.write('2,数据,200\n')
            temp_file = f.name

        try:
            dal = DataAccessLayer.load_from_file(temp_file)

            assert dal.shape == (2, 3)
            assert len(dal.columns) == 3
        finally:
            os.unlink(temp_file)

    def test_load_empty_file(self):
        """Loading a CSV that has a header but no data rows raises DataLoadError."""
        with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
            f.write('id,name\n')  # header only, no data rows
            temp_file = f.name

        try:
            # The error message is expected to contain the text "为空".
            with pytest.raises(DataLoadError, match="为空"):
                DataAccessLayer.load_from_file(temp_file)
        finally:
            os.unlink(temp_file)

    def test_load_invalid_file(self):
        """Loading a path that does not exist raises DataLoadError."""
        with pytest.raises(DataLoadError):
            DataAccessLayer.load_from_file('nonexistent_file.csv')

    def test_get_profile_basic(self):
        """Generate a profile for a small frame and check its basic fields."""
        df = pd.DataFrame({
            'id': [1, 2, 3, 4, 5],
            'name': ['A', 'B', 'C', 'D', 'E'],
            'value': [10, 20, 30, 40, 50],
            'status': ['open', 'closed', 'open', 'closed', 'open']
        })

        dal = DataAccessLayer(df, file_path='test.csv')
        profile = dal.get_profile()

        # Basic information
        assert profile.file_path == 'test.csv'
        assert profile.row_count == 5
        assert profile.column_count == 4
        assert len(profile.columns) == 4

        # Every input column is described in the profile
        col_names = [col.name for col in profile.columns]
        assert 'id' in col_names
        assert 'name' in col_names
        assert 'value' in col_names
        assert 'status' in col_names

    def test_get_profile_with_missing_values(self):
        """The profile reports the missing-value rate of a column with gaps."""
        df = pd.DataFrame({
            'id': [1, 2, 3, 4, 5],
            'value': [10, None, 30, None, 50]
        })

        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        value_col = next(col for col in profile.columns if col.name == 'value')

        # 2 of 5 values are missing. Compare with a tolerance rather than exact
        # float equality so the test does not depend on how the ratio is computed.
        assert value_col.missing_rate == pytest.approx(0.4)

    def test_column_type_inference_numeric(self):
        """Integer and float columns are both profiled as 'numeric'."""
        df = pd.DataFrame({
            'int_col': [1, 2, 3, 4, 5],
            'float_col': [1.1, 2.2, 3.3, 4.4, 5.5]
        })

        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        int_col = next(col for col in profile.columns if col.name == 'int_col')
        float_col = next(col for col in profile.columns if col.name == 'float_col')

        assert int_col.dtype == 'numeric'
        assert float_col.dtype == 'numeric'

        # Numeric columns carry summary statistics
        assert 'mean' in int_col.statistics
        assert 'std' in int_col.statistics
        assert 'min' in int_col.statistics
        assert 'max' in int_col.statistics

    def test_column_type_inference_categorical(self):
        """A string column with few distinct values is profiled as 'categorical'."""
        df = pd.DataFrame({
            # 100 rows drawn from only two distinct values
            'status': ['open', 'closed', 'open', 'closed', 'open'] * 20
        })

        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        status_col = profile.columns[0]
        assert status_col.dtype == 'categorical'

        # Categorical columns carry category statistics
        assert 'top_values' in status_col.statistics
        assert 'num_categories' in status_col.statistics

    def test_column_type_inference_datetime(self):
        """A datetime64 column is profiled as 'datetime'."""
        df = pd.DataFrame({
            'date': pd.date_range('2020-01-01', periods=10)
        })

        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        date_col = profile.columns[0]
        assert date_col.dtype == 'datetime'

    def test_sample_values_limit(self):
        """A column profile exposes at most 5 sample values."""
        df = pd.DataFrame({
            'id': list(range(100))
        })

        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        id_col = profile.columns[0]
        assert len(id_col.sample_values) <= 5

    def test_sanitize_result_dataframe(self):
        """_sanitize_result caps a DataFrame payload at 100 rows."""
        df = pd.DataFrame({
            'id': list(range(200)),
            'value': list(range(200))
        })

        dal = DataAccessLayer(df)

        # Simulate a tool returning a large amount of data
        result = {'data': df}
        sanitized = dal._sanitize_result(result)

        assert len(sanitized['data']) <= 100

    def test_sanitize_result_series(self):
        """_sanitize_result caps a Series payload at 100 entries."""
        df = pd.DataFrame({
            'id': list(range(200))
        })

        dal = DataAccessLayer(df)

        # Simulate a tool returning a Series
        result = {'data': df['id']}
        sanitized = dal._sanitize_result(result)

        assert len(sanitized['data']) <= 100

    def test_large_dataset_sampling(self):
        """Load a 1000-row CSV and check that every row is present.

        NOTE: despite the name, this test writes only 1000 rows (a genuinely
        large file would make the test slow), so it verifies plain loading and
        does not exercise any large-dataset sampling path.
        """
        with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
            f.write('id,value\n')
            # Small stand-in for a large dataset
            for i in range(1000):
                f.write(f'{i},{i*10}\n')
            temp_file = f.name

        try:
            dal = DataAccessLayer.load_from_file(temp_file)
            # All rows were loaded
            assert dal.shape[0] == 1000
        finally:
            os.unlink(temp_file)
|
||||
|
||||
|
||||
class TestDataAccessLayerIntegration:
    """Integration tests for the data access layer."""

    def test_end_to_end_workflow(self):
        """Write a frame to CSV, load it from disk, and verify the generated profile."""
        df = pd.DataFrame({
            'id': [1, 2, 3, 4, 5],
            'status': ['open', 'closed', 'open', 'closed', 'pending'],
            'value': [100, 200, 150, 300, 250],
            'created_at': pd.date_range('2020-01-01', periods=5)
        })

        # Reserve a temporary path and close the handle before writing, so the
        # CSV is written through a single handle instead of reopening the file
        # by name while it is still open.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
            temp_file = f.name

        try:
            # Writing inside the try block guarantees the temporary file is
            # removed even if serialisation fails.
            df.to_csv(temp_file, index=False)

            # 1. Load the data
            dal = DataAccessLayer.load_from_file(temp_file)

            # 2. Generate the data profile
            profile = dal.get_profile()

            # 3. Verify the profile dimensions
            assert profile.row_count == 5
            assert profile.column_count == 4

            # 4. Verify column type inference
            col_types = {col.name: col.dtype for col in profile.columns}
            assert col_types['id'] == 'numeric'
            assert col_types['status'] == 'categorical'
            assert col_types['value'] == 'numeric'
            assert col_types['created_at'] == 'datetime'

            # 5. Verify statistics: mean of [100, 200, 150, 300, 250] is 200
            value_col = next(col for col in profile.columns if col.name == 'value')
            assert 'mean' in value_col.statistics
            assert value_col.statistics['mean'] == pytest.approx(200.0)
        finally:
            os.unlink(temp_file)
|
||||
Reference in New Issue
Block a user