Files
vibe_data_ana/tests/test_data_access.py

269 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""数据访问层的单元测试。"""
import pytest
import pandas as pd
import tempfile
import os
from pathlib import Path
from src.data_access import DataAccessLayer, DataLoadError
class TestDataAccessLayer:
    """Unit tests for the data access layer."""

    @staticmethod
    def _write_temp_csv(text, encoding='utf-8'):
        """Write ``text`` to a new temporary ``.csv`` file and return its path.

        The file is created with ``delete=False`` and is closed before this
        returns, so it can be reopened by the loader. The caller is
        responsible for removing it with ``os.unlink``.
        """
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.csv', delete=False, encoding=encoding
        ) as f:
            f.write(text)
            return f.name

    def test_load_utf8_csv(self):
        """Load a UTF-8 encoded CSV file containing non-ASCII values."""
        temp_file = self._write_temp_csv(
            'id,name,value\n'
            '1,测试,100\n'
            '2,数据,200\n'
        )
        try:
            dal = DataAccessLayer.load_from_file(temp_file)
            assert dal.shape == (2, 3)
            for column in ('id', 'name', 'value'):
                assert column in dal.columns
        finally:
            os.unlink(temp_file)

    def test_load_gbk_csv(self):
        """Load a GBK encoded CSV file with non-ASCII headers and values."""
        temp_file = self._write_temp_csv(
            '编号,名称,数值\n'
            '1,测试,100\n'
            '2,数据,200\n',
            encoding='gbk',
        )
        try:
            dal = DataAccessLayer.load_from_file(temp_file)
            assert dal.shape == (2, 3)
            assert len(dal.columns) == 3
        finally:
            os.unlink(temp_file)

    def test_load_empty_file(self):
        """A CSV file with a header row but no data rows raises DataLoadError."""
        temp_file = self._write_temp_csv('id,name\n')  # header only, no data
        try:
            # The expected error message contains "为空" ("is empty").
            with pytest.raises(DataLoadError, match="为空"):
                DataAccessLayer.load_from_file(temp_file)
        finally:
            os.unlink(temp_file)

    def test_load_invalid_file(self):
        """Loading a path that does not exist raises DataLoadError."""
        # Build the path inside a fresh, empty temporary directory so it is
        # guaranteed not to exist, regardless of the current working directory.
        with tempfile.TemporaryDirectory() as temp_dir:
            missing_file = os.path.join(temp_dir, 'nonexistent_file.csv')
            with pytest.raises(DataLoadError):
                DataAccessLayer.load_from_file(missing_file)

    def test_get_profile_basic(self):
        """The profile reports the file path, row/column counts and columns."""
        df = pd.DataFrame({
            'id': [1, 2, 3, 4, 5],
            'name': ['A', 'B', 'C', 'D', 'E'],
            'value': [10, 20, 30, 40, 50],
            'status': ['open', 'closed', 'open', 'closed', 'open'],
        })
        dal = DataAccessLayer(df, file_path='test.csv')
        profile = dal.get_profile()

        # Top-level profile information.
        assert profile.file_path == 'test.csv'
        assert profile.row_count == 5
        assert profile.column_count == 4
        assert len(profile.columns) == 4

        # Every input column appears in the profile.
        col_names = [col.name for col in profile.columns]
        for column in ('id', 'name', 'value', 'status'):
            assert column in col_names

    def test_get_profile_with_missing_values(self):
        """The profile reports the fraction of missing values per column."""
        df = pd.DataFrame({
            'id': [1, 2, 3, 4, 5],
            'value': [10, None, 30, None, 50],
        })
        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        value_col = next(col for col in profile.columns if col.name == 'value')
        # 2 of 5 values are missing; compare with a tolerance because the
        # rate is a floating-point result.
        assert value_col.missing_rate == pytest.approx(0.4)

    def test_column_type_inference_numeric(self):
        """Integer and float columns are inferred as 'numeric'."""
        df = pd.DataFrame({
            'int_col': [1, 2, 3, 4, 5],
            'float_col': [1.1, 2.2, 3.3, 4.4, 5.5],
        })
        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        int_col = next(col for col in profile.columns if col.name == 'int_col')
        float_col = next(col for col in profile.columns if col.name == 'float_col')
        assert int_col.dtype == 'numeric'
        assert float_col.dtype == 'numeric'

        # Numeric columns carry summary statistics.
        for stat in ('mean', 'std', 'min', 'max'):
            assert stat in int_col.statistics

    def test_column_type_inference_categorical(self):
        """A low-cardinality string column is inferred as 'categorical'."""
        df = pd.DataFrame({
            'status': ['open', 'closed', 'open', 'closed', 'open'] * 20,
        })
        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        status_col = profile.columns[0]
        assert status_col.dtype == 'categorical'
        # Categorical columns carry category statistics.
        assert 'top_values' in status_col.statistics
        assert 'num_categories' in status_col.statistics

    def test_column_type_inference_datetime(self):
        """A datetime64 column is inferred as 'datetime'."""
        df = pd.DataFrame({
            'date': pd.date_range('2020-01-01', periods=10),
        })
        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        date_col = profile.columns[0]
        assert date_col.dtype == 'datetime'

    def test_sample_values_limit(self):
        """Each column profile exposes at most 5 sample values."""
        df = pd.DataFrame({
            'id': list(range(100)),
        })
        dal = DataAccessLayer(df)
        profile = dal.get_profile()

        id_col = profile.columns[0]
        assert len(id_col.sample_values) <= 5

    def test_sanitize_result_dataframe(self):
        """A DataFrame in a tool result is truncated to at most 100 rows."""
        df = pd.DataFrame({
            'id': list(range(200)),
            'value': list(range(200)),
        })
        dal = DataAccessLayer(df)

        # Simulate a tool returning a large amount of data.
        result = {'data': df}
        sanitized = dal._sanitize_result(result)
        assert len(sanitized['data']) <= 100

    def test_sanitize_result_series(self):
        """A Series in a tool result is truncated to at most 100 entries."""
        df = pd.DataFrame({
            'id': list(range(200)),
        })
        dal = DataAccessLayer(df)

        # Simulate a tool returning a Series.
        result = {'data': df['id']}
        sanitized = dal._sanitize_result(result)
        assert len(sanitized['data']) <= 100

    def test_large_dataset_sampling(self):
        """Load a moderately sized CSV file and check every row is present.

        NOTE(review): despite its name, this test writes only 1,000 rows,
        because a genuinely large file would make the suite slow. It therefore
        verifies full loading rather than exercising any sampling path. The
        original intent appears to have been a file with more than one
        million rows. A real sampling test needs input above whatever
        threshold the loader uses (confirm in ``src.data_access``).
        """
        rows = ''.join(f'{i},{i * 10}\n' for i in range(1000))
        temp_file = self._write_temp_csv('id,value\n' + rows)
        try:
            dal = DataAccessLayer.load_from_file(temp_file)
            assert dal.shape[0] == 1000
        finally:
            os.unlink(temp_file)
class TestDataAccessLayerIntegration:
    """Integration tests for the data access layer."""

    def test_end_to_end_workflow(self):
        """Round-trip a DataFrame through a CSV file, then profile it.

        Steps: write the CSV, load it with ``DataAccessLayer.load_from_file``,
        build the profile, and check row/column counts, inferred column types
        and a summary statistic.
        """
        df = pd.DataFrame({
            'id': [1, 2, 3, 4, 5],
            'status': ['open', 'closed', 'open', 'closed', 'pending'],
            'value': [100, 200, 150, 300, 250],
            'created_at': pd.date_range('2020-01-01', periods=5),
        })

        # Reserve a unique temporary path and close the handle immediately.
        # pandas then writes the CSV through its own handle, so no second
        # handle stays open on the same file while it is written.
        with tempfile.NamedTemporaryFile(suffix='.csv', delete=False) as f:
            temp_file = f.name
        try:
            # pandas' default encoding for to_csv is UTF-8.
            df.to_csv(temp_file, index=False)

            # 1. Load the data.
            dal = DataAccessLayer.load_from_file(temp_file)
            # 2. Build the data profile.
            profile = dal.get_profile()

            # 3. Check the top-level profile.
            assert profile.row_count == 5
            assert profile.column_count == 4

            # 4. Check column type inference.
            col_types = {col.name: col.dtype for col in profile.columns}
            assert col_types['id'] == 'numeric'
            assert col_types['status'] == 'categorical'
            assert col_types['value'] == 'numeric'
            assert col_types['created_at'] == 'datetime'

            # 5. Check summary statistics: mean of [100, 200, 150, 300, 250]
            # is 200. Compare with a tolerance because it is a float result.
            value_col = next(col for col in profile.columns if col.name == 'value')
            assert 'mean' in value_col.statistics
            assert value_col.statistics['mean'] == pytest.approx(200.0)
        finally:
            os.unlink(temp_file)