269 lines
8.9 KiB
Python
269 lines
8.9 KiB
Python
"""数据访问层的单元测试。"""
|
||
|
||
import pytest
|
||
import pandas as pd
|
||
import tempfile
|
||
import os
|
||
from pathlib import Path
|
||
|
||
from src.data_access import DataAccessLayer, DataLoadError
|
||
|
||
|
||
class TestDataAccessLayer:
    """Unit tests for the data access layer (``DataAccessLayer``)."""

    @staticmethod
    def _write_csv(path, lines, encoding='utf-8'):
        """Write *lines* as a CSV file at *path* and return the path as ``str``.

        Each entry of *lines* is one CSV row without its trailing newline.
        A plain ``str`` is returned because every call site in this module
        passes string paths to ``DataAccessLayer.load_from_file``; whether it
        also accepts ``pathlib.Path`` objects is not shown here.
        """
        path.write_text(''.join(f'{line}\n' for line in lines), encoding=encoding)
        return str(path)

    def test_load_utf8_csv(self, tmp_path):
        """Load a UTF-8 encoded CSV file."""
        csv_path = self._write_csv(
            tmp_path / 'utf8.csv',
            ['id,name,value', '1,测试,100', '2,数据,200'],
        )

        dal = DataAccessLayer.load_from_file(csv_path)

        assert dal.shape == (2, 3)
        assert 'id' in dal.columns
        assert 'name' in dal.columns
        assert 'value' in dal.columns

    def test_load_gbk_csv(self, tmp_path):
        """Load a GBK encoded CSV file and check that its header decodes correctly."""
        csv_path = self._write_csv(
            tmp_path / 'gbk.csv',
            ['编号,名称,数值', '1,测试,100', '2,数据,200'],
            encoding='gbk',
        )

        dal = DataAccessLayer.load_from_file(csv_path)

        assert dal.shape == (2, 3)
        assert len(dal.columns) == 3
        # Shape alone would also pass if the bytes were mis-decoded with a
        # permissive codec such as latin-1, so pin the decoded header names.
        assert '编号' in dal.columns
        assert '名称' in dal.columns
        assert '数值' in dal.columns

    def test_load_empty_file(self, tmp_path):
        """A CSV with a header row but no data rows raises DataLoadError."""
        csv_path = self._write_csv(tmp_path / 'empty.csv', ['id,name'])

        # The error message is expected to contain "为空" ("is empty").
        with pytest.raises(DataLoadError, match="为空"):
            DataAccessLayer.load_from_file(csv_path)

    def test_load_invalid_file(self, tmp_path):
        """Loading a file that does not exist raises DataLoadError."""
        # Build the path inside a fresh temporary directory so it is guaranteed
        # not to exist, independent of the current working directory.
        missing = tmp_path / 'nonexistent_file.csv'

        with pytest.raises(DataLoadError):
            DataAccessLayer.load_from_file(str(missing))

    def test_get_profile_basic(self):
        """Build a basic data profile from an in-memory DataFrame."""
        df = pd.DataFrame({
            'id': [1, 2, 3, 4, 5],
            'name': ['A', 'B', 'C', 'D', 'E'],
            'value': [10, 20, 30, 40, 50],
            'status': ['open', 'closed', 'open', 'closed', 'open'],
        })

        dal = DataAccessLayer(df, file_path='test.csv')
        profile = dal.get_profile()

        # Table-level information.
        assert profile.file_path == 'test.csv'
        assert profile.row_count == 5
        assert profile.column_count == 4
        assert len(profile.columns) == 4

        # Every source column appears in the profile.
        col_names = [col.name for col in profile.columns]
        assert 'id' in col_names
        assert 'name' in col_names
        assert 'value' in col_names
        assert 'status' in col_names

    def test_get_profile_with_missing_values(self):
        """The profile reports the fraction of missing values per column."""
        df = pd.DataFrame({
            'id': [1, 2, 3, 4, 5],
            'value': [10, None, 30, None, 50],
        })

        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        value_col = next(col for col in profile.columns if col.name == 'value')

        # 2 of 5 values are missing; compare with a tolerance since the rate
        # is a float computed by the profiler.
        assert value_col.missing_rate == pytest.approx(0.4)

    def test_column_type_inference_numeric(self):
        """Integer and float columns are both inferred as 'numeric'."""
        df = pd.DataFrame({
            'int_col': [1, 2, 3, 4, 5],
            'float_col': [1.1, 2.2, 3.3, 4.4, 5.5],
        })

        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        int_col = next(col for col in profile.columns if col.name == 'int_col')
        float_col = next(col for col in profile.columns if col.name == 'float_col')

        assert int_col.dtype == 'numeric'
        assert float_col.dtype == 'numeric'

        # Numeric columns carry summary statistics.
        assert 'mean' in int_col.statistics
        assert 'std' in int_col.statistics
        assert 'min' in int_col.statistics
        assert 'max' in int_col.statistics

    def test_column_type_inference_categorical(self):
        """A low-cardinality string column is inferred as 'categorical'."""
        df = pd.DataFrame({
            'status': ['open', 'closed', 'open', 'closed', 'open'] * 20,
        })

        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        status_col = profile.columns[0]
        assert status_col.dtype == 'categorical'

        # Categorical columns carry frequency-oriented statistics.
        assert 'top_values' in status_col.statistics
        assert 'num_categories' in status_col.statistics

    def test_column_type_inference_datetime(self):
        """A pandas datetime column is inferred as 'datetime'."""
        df = pd.DataFrame({
            'date': pd.date_range('2020-01-01', periods=10),
        })

        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        date_col = profile.columns[0]
        assert date_col.dtype == 'datetime'

    def test_sample_values_limit(self):
        """At most 5 sample values are kept per column."""
        df = pd.DataFrame({
            'id': list(range(100)),
        })

        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        id_col = profile.columns[0]
        assert len(id_col.sample_values) <= 5

    def test_sanitize_result_dataframe(self):
        """A DataFrame in a tool result is truncated to at most 100 rows."""
        df = pd.DataFrame({
            'id': list(range(200)),
            'value': list(range(200)),
        })

        dal = DataAccessLayer(df)

        # Simulate a tool that returns a large amount of data.
        result = {'data': df}
        sanitized = dal._sanitize_result(result)

        assert len(sanitized['data']) <= 100

    def test_sanitize_result_series(self):
        """A Series in a tool result is truncated to at most 100 entries."""
        df = pd.DataFrame({
            'id': list(range(200)),
        })

        dal = DataAccessLayer(df)

        # Simulate a tool that returns a Series.
        result = {'data': df['id']}
        sanitized = dal._sanitize_result(result)

        assert len(sanitized['data']) <= 100

    def test_large_dataset_sampling(self, tmp_path):
        """Load a moderately large (1,000-row) CSV file.

        Sampling is presumably only triggered for much larger inputs
        (over one million rows -- TODO confirm against DataAccessLayer).
        Generating such a file would make this unit test slow, so it only
        checks that all 1,000 rows are loaded.
        """
        rows = ['id,value'] + [f'{i},{i * 10}' for i in range(1000)]
        csv_path = self._write_csv(tmp_path / 'large.csv', rows)

        dal = DataAccessLayer.load_from_file(csv_path)

        assert dal.shape[0] == 1000
|
||
|
||
|
||
class TestDataAccessLayerIntegration:
    """Integration tests for the data access layer."""

    def test_end_to_end_workflow(self, tmp_path):
        """Round-trip a DataFrame through CSV, load it, and check its profile."""
        df = pd.DataFrame({
            'id': [1, 2, 3, 4, 5],
            'status': ['open', 'closed', 'open', 'closed', 'pending'],
            'value': [100, 200, 150, 300, 250],
            'created_at': pd.date_range('2020-01-01', periods=5),
        })

        # Write to a fresh path rather than reopening a still-open
        # NamedTemporaryFile by name, which the tempfile docs note is not
        # portable across platforms; tmp_path is also cleaned up by pytest.
        csv_path = tmp_path / 'workflow.csv'
        df.to_csv(csv_path, index=False)

        # 1. Load the data (string path, as elsewhere in this module).
        dal = DataAccessLayer.load_from_file(str(csv_path))

        # 2. Build the data profile.
        profile = dal.get_profile()

        # 3. Table-level information.
        assert profile.row_count == 5
        assert profile.column_count == 4

        # 4. Column type inference, including datetime strings read back from CSV.
        col_types = {col.name: col.dtype for col in profile.columns}
        assert col_types['id'] == 'numeric'
        assert col_types['status'] == 'categorical'
        assert col_types['value'] == 'numeric'
        assert col_types['created_at'] == 'datetime'

        # 5. Statistics: mean of [100, 200, 150, 300, 250] is 200.
        value_col = next(col for col in profile.columns if col.name == 'value')
        assert 'mean' in value_col.statistics
        assert value_col.statistics['mean'] == pytest.approx(200.0)
|