"""数据访问层的单元测试。""" import pytest import pandas as pd import tempfile import os from pathlib import Path from src.data_access import DataAccessLayer, DataLoadError class TestDataAccessLayer: """数据访问层的单元测试。""" def test_load_utf8_csv(self): """测试加载 UTF-8 编码的 CSV 文件。""" # 创建临时 CSV 文件 with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f: f.write('id,name,value\n') f.write('1,测试,100\n') f.write('2,数据,200\n') temp_file = f.name try: # 加载数据 dal = DataAccessLayer.load_from_file(temp_file) assert dal.shape == (2, 3) assert 'id' in dal.columns assert 'name' in dal.columns assert 'value' in dal.columns finally: os.unlink(temp_file) def test_load_gbk_csv(self): """测试加载 GBK 编码的 CSV 文件。""" # 创建临时 GBK 编码的 CSV 文件 with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='gbk') as f: f.write('编号,名称,数值\n') f.write('1,测试,100\n') f.write('2,数据,200\n') temp_file = f.name try: # 加载数据 dal = DataAccessLayer.load_from_file(temp_file) assert dal.shape == (2, 3) assert len(dal.columns) == 3 finally: os.unlink(temp_file) def test_load_empty_file(self): """测试加载空文件。""" # 创建空的 CSV 文件 with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f: f.write('id,name\n') # 只有表头,没有数据 temp_file = f.name try: # 应该抛出 DataLoadError with pytest.raises(DataLoadError, match="为空"): DataAccessLayer.load_from_file(temp_file) finally: os.unlink(temp_file) def test_load_invalid_file(self): """测试加载不存在的文件。""" with pytest.raises(DataLoadError): DataAccessLayer.load_from_file('nonexistent_file.csv') def test_get_profile_basic(self): """测试生成基本数据画像。""" # 创建测试数据 df = pd.DataFrame({ 'id': [1, 2, 3, 4, 5], 'name': ['A', 'B', 'C', 'D', 'E'], 'value': [10, 20, 30, 40, 50], 'status': ['open', 'closed', 'open', 'closed', 'open'] }) dal = DataAccessLayer(df, file_path='test.csv') profile = dal.get_profile() # 验证基本信息 assert profile.file_path == 'test.csv' assert profile.row_count == 5 assert profile.column_count == 4 assert len(profile.columns) == 4 # 验证列信息 col_names = [col.name for col in profile.columns] assert 'id' in col_names assert 'name' in col_names assert 'value' in col_names assert 'status' in col_names def test_get_profile_with_missing_values(self): """测试包含缺失值的数据画像。""" df = pd.DataFrame({ 'id': [1, 2, 3, 4, 5], 'value': [10, None, 30, None, 50] }) dal = DataAccessLayer(df) profile = dal.get_profile() # 查找 value 列 value_col = next(col for col in profile.columns if col.name == 'value') # 验证缺失率 assert value_col.missing_rate == 0.4 # 2/5 = 0.4 def test_column_type_inference_numeric(self): """测试数值类型推断。""" df = pd.DataFrame({ 'int_col': [1, 2, 3, 4, 5], 'float_col': [1.1, 2.2, 3.3, 4.4, 5.5] }) dal = DataAccessLayer(df) profile = dal.get_profile() int_col = next(col for col in profile.columns if col.name == 'int_col') float_col = next(col for col in profile.columns if col.name == 'float_col') assert int_col.dtype == 'numeric' assert float_col.dtype == 'numeric' # 验证统计信息 assert 'mean' in int_col.statistics assert 'std' in int_col.statistics assert 'min' in int_col.statistics assert 'max' in int_col.statistics def test_column_type_inference_categorical(self): """测试分类类型推断。""" df = pd.DataFrame({ 'status': ['open', 'closed', 'open', 'closed', 'open'] * 20 }) dal = DataAccessLayer(df) profile = dal.get_profile() status_col = profile.columns[0] assert status_col.dtype == 'categorical' # 验证统计信息 assert 'top_values' in status_col.statistics assert 'num_categories' in status_col.statistics def test_column_type_inference_datetime(self): """测试日期时间类型推断。""" df = pd.DataFrame({ 'date': pd.date_range('2020-01-01', periods=10) }) dal = DataAccessLayer(df) profile = dal.get_profile() date_col = profile.columns[0] assert date_col.dtype == 'datetime' def test_sample_values_limit(self): """测试示例值数量限制。""" df = pd.DataFrame({ 'id': list(range(100)) }) dal = DataAccessLayer(df) profile = dal.get_profile() id_col = profile.columns[0] # 示例值应该最多5个 assert len(id_col.sample_values) <= 5 def test_sanitize_result_dataframe(self): """测试结果过滤 - DataFrame。""" df = pd.DataFrame({ 'id': list(range(200)), 'value': list(range(200)) }) dal = DataAccessLayer(df) # 模拟工具返回大量数据 result = {'data': df} sanitized = dal._sanitize_result(result) # 验证:返回的数据应该被截断到100行 assert len(sanitized['data']) <= 100 def test_sanitize_result_series(self): """测试结果过滤 - Series。""" df = pd.DataFrame({ 'id': list(range(200)) }) dal = DataAccessLayer(df) # 模拟工具返回 Series result = {'data': df['id']} sanitized = dal._sanitize_result(result) # 验证:返回的数据应该被截断 assert len(sanitized['data']) <= 100 def test_large_dataset_sampling(self): """测试大数据集采样。""" # 创建超过100万行的临时文件 with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f: f.write('id,value\n') # 写入少量数据用于测试(实际测试大数据集会很慢) for i in range(1000): f.write(f'{i},{i*10}\n') temp_file = f.name try: dal = DataAccessLayer.load_from_file(temp_file) # 验证数据被加载 assert dal.shape[0] == 1000 finally: os.unlink(temp_file) class TestDataAccessLayerIntegration: """数据访问层的集成测试。""" def test_end_to_end_workflow(self): """测试端到端工作流程。""" # 创建测试数据 df = pd.DataFrame({ 'id': [1, 2, 3, 4, 5], 'status': ['open', 'closed', 'open', 'closed', 'pending'], 'value': [100, 200, 150, 300, 250], 'created_at': pd.date_range('2020-01-01', periods=5) }) # 保存到临时文件 with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f: df.to_csv(f.name, index=False) temp_file = f.name try: # 1. 加载数据 dal = DataAccessLayer.load_from_file(temp_file) # 2. 生成数据画像 profile = dal.get_profile() # 3. 验证数据画像 assert profile.row_count == 5 assert profile.column_count == 4 # 4. 验证列类型推断 col_types = {col.name: col.dtype for col in profile.columns} assert col_types['id'] == 'numeric' assert col_types['status'] == 'categorical' assert col_types['value'] == 'numeric' assert col_types['created_at'] == 'datetime' # 5. 验证统计信息 value_col = next(col for col in profile.columns if col.name == 'value') assert 'mean' in value_col.statistics assert value_col.statistics['mean'] == 200.0 finally: os.unlink(temp_file)