"""工具系统的单元测试。""" import pytest import pandas as pd import numpy as np from datetime import datetime, timedelta from src.tools.base import AnalysisTool, ToolRegistry from src.tools.query_tools import ( GetColumnDistributionTool, GetValueCountsTool, GetTimeSeriesTool, GetCorrelationTool ) from src.tools.stats_tools import ( CalculateStatisticsTool, PerformGroupbyTool, DetectOutliersTool, CalculateTrendTool ) from src.models import DataProfile, ColumnInfo class TestGetColumnDistributionTool: """测试列分布工具。""" def test_basic_functionality(self): """测试基本功能。""" tool = GetColumnDistributionTool() df = pd.DataFrame({ 'status': ['open', 'closed', 'open', 'pending', 'closed', 'open'] }) result = tool.execute(df, column='status') assert 'distribution' in result assert result['column'] == 'status' assert result['total_count'] == 6 assert result['unique_count'] == 3 assert len(result['distribution']) == 3 def test_top_n_limit(self): """测试 top_n 参数限制。""" tool = GetColumnDistributionTool() df = pd.DataFrame({ 'value': list(range(20)) }) result = tool.execute(df, column='value', top_n=5) assert len(result['distribution']) == 5 def test_nonexistent_column(self): """测试不存在的列。""" tool = GetColumnDistributionTool() df = pd.DataFrame({'col1': [1, 2, 3]}) result = tool.execute(df, column='nonexistent') assert 'error' in result class TestGetValueCountsTool: """测试值计数工具。""" def test_basic_functionality(self): """测试基本功能。""" tool = GetValueCountsTool() df = pd.DataFrame({ 'category': ['A', 'B', 'A', 'C', 'B', 'A'] }) result = tool.execute(df, column='category') assert 'value_counts' in result assert result['value_counts']['A'] == 3 assert result['value_counts']['B'] == 2 assert result['value_counts']['C'] == 1 def test_normalized_counts(self): """测试归一化计数。""" tool = GetValueCountsTool() df = pd.DataFrame({ 'category': ['A', 'A', 'B', 'B'] }) result = tool.execute(df, column='category', normalize=True) assert result['normalized'] is True assert abs(result['value_counts']['A'] - 0.5) < 0.01 assert abs(result['value_counts']['B'] - 0.5) < 0.01 class TestGetTimeSeriesTool: """测试时间序列工具。""" def test_basic_functionality(self): """测试基本功能。""" tool = GetTimeSeriesTool() dates = pd.date_range('2020-01-01', periods=10, freq='D') df = pd.DataFrame({ 'date': dates, 'value': range(10) }) result = tool.execute(df, time_column='date', value_column='value', aggregation='sum') assert 'time_series' in result assert result['time_column'] == 'date' assert result['aggregation'] == 'sum' assert len(result['time_series']) > 0 def test_count_aggregation(self): """测试计数聚合。""" tool = GetTimeSeriesTool() dates = pd.date_range('2020-01-01', periods=5, freq='D') df = pd.DataFrame({'date': dates}) result = tool.execute(df, time_column='date', aggregation='count') assert 'time_series' in result assert len(result['time_series']) > 0 def test_output_limit(self): """测试输出限制(不超过100行)。""" tool = GetTimeSeriesTool() dates = pd.date_range('2020-01-01', periods=200, freq='D') df = pd.DataFrame({'date': dates}) result = tool.execute(df, time_column='date') assert len(result['time_series']) <= 100 assert result['total_points'] == 200 assert result['returned_points'] == 100 class TestGetCorrelationTool: """测试相关性分析工具。""" def test_basic_functionality(self): """测试基本功能。""" tool = GetCorrelationTool() df = pd.DataFrame({ 'x': [1, 2, 3, 4, 5], 'y': [2, 4, 6, 8, 10], 'z': [1, 1, 1, 1, 1] }) result = tool.execute(df) assert 'correlation_matrix' in result assert 'x' in result['correlation_matrix'] assert 'y' in result['correlation_matrix'] # x 和 y 完全正相关 assert abs(result['correlation_matrix']['x']['y'] - 1.0) < 0.01 def test_insufficient_numeric_columns(self): """测试数值列不足的情况。""" tool = GetCorrelationTool() df = pd.DataFrame({ 'x': [1, 2, 3], 'text': ['a', 'b', 'c'] }) result = tool.execute(df) assert 'error' in result class TestCalculateStatisticsTool: """测试统计计算工具。""" def test_basic_functionality(self): """测试基本功能。""" tool = CalculateStatisticsTool() df = pd.DataFrame({ 'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] }) result = tool.execute(df, column='values') assert result['mean'] == 5.5 assert result['median'] == 5.5 assert result['min'] == 1 assert result['max'] == 10 assert result['count'] == 10 def test_non_numeric_column(self): """测试非数值列。""" tool = CalculateStatisticsTool() df = pd.DataFrame({ 'text': ['a', 'b', 'c'] }) result = tool.execute(df, column='text') assert 'error' in result class TestPerformGroupbyTool: """测试分组聚合工具。""" def test_basic_functionality(self): """测试基本功能。""" tool = PerformGroupbyTool() df = pd.DataFrame({ 'category': ['A', 'B', 'A', 'B', 'A'], 'value': [10, 20, 30, 40, 50] }) result = tool.execute(df, group_by='category', value_column='value', aggregation='sum') assert 'groups' in result assert len(result['groups']) == 2 # 找到 A 组的总和 group_a = next(g for g in result['groups'] if g['group'] == 'A') assert group_a['value'] == 90 # 10 + 30 + 50 def test_count_aggregation(self): """测试计数聚合。""" tool = PerformGroupbyTool() df = pd.DataFrame({ 'category': ['A', 'B', 'A', 'B', 'A'] }) result = tool.execute(df, group_by='category') assert len(result['groups']) == 2 group_a = next(g for g in result['groups'] if g['group'] == 'A') assert group_a['value'] == 3 def test_output_limit(self): """测试输出限制(不超过100组)。""" tool = PerformGroupbyTool() df = pd.DataFrame({ 'category': [f'cat_{i}' for i in range(200)], 'value': range(200) }) result = tool.execute(df, group_by='category', value_column='value', aggregation='sum') assert len(result['groups']) <= 100 assert result['total_groups'] == 200 assert result['returned_groups'] == 100 class TestDetectOutliersTool: """测试异常值检测工具。""" def test_iqr_method(self): """测试 IQR 方法。""" tool = DetectOutliersTool() # 创建包含明显异常值的数据 df = pd.DataFrame({ 'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 100] }) result = tool.execute(df, column='values', method='iqr') assert result['outlier_count'] > 0 assert 100 in result['outlier_values'] def test_zscore_method(self): """测试 Z-score 方法。""" tool = DetectOutliersTool() df = pd.DataFrame({ 'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 100] }) result = tool.execute(df, column='values', method='zscore', threshold=2) assert result['outlier_count'] > 0 assert result['method'] == 'zscore' class TestCalculateTrendTool: """测试趋势计算工具。""" def test_increasing_trend(self): """测试上升趋势。""" tool = CalculateTrendTool() dates = pd.date_range('2020-01-01', periods=10, freq='D') df = pd.DataFrame({ 'date': dates, 'value': range(10) }) result = tool.execute(df, time_column='date', value_column='value') assert result['trend'] == 'increasing' assert result['slope'] > 0 assert result['r_squared'] > 0.9 # 完美线性关系 def test_decreasing_trend(self): """测试下降趋势。""" tool = CalculateTrendTool() dates = pd.date_range('2020-01-01', periods=10, freq='D') df = pd.DataFrame({ 'date': dates, 'value': list(range(10, 0, -1)) }) result = tool.execute(df, time_column='date', value_column='value') assert result['trend'] == 'decreasing' assert result['slope'] < 0 class TestToolParameterValidation: """测试工具参数验证。""" def test_missing_required_parameter(self): """测试缺少必需参数。""" tool = GetColumnDistributionTool() df = pd.DataFrame({'col': [1, 2, 3]}) # 不提供必需的 column 参数 result = tool.execute(df) # 应该返回错误或引发异常 assert 'error' in result or result is None def test_invalid_aggregation_method(self): """测试无效的聚合方法。""" tool = PerformGroupbyTool() df = pd.DataFrame({ 'category': ['A', 'B'], 'value': [1, 2] }) result = tool.execute(df, group_by='category', value_column='value', aggregation='invalid') assert 'error' in result class TestToolErrorHandling: """测试工具错误处理。""" def test_empty_dataframe(self): """测试空 DataFrame。""" tool = CalculateStatisticsTool() df = pd.DataFrame() result = tool.execute(df, column='nonexistent') assert 'error' in result def test_all_null_values(self): """测试全部为空值的列。""" tool = CalculateStatisticsTool() df = pd.DataFrame({ 'values': [None, None, None] }) result = tool.execute(df, column='values') # 应该处理空值情况 assert 'error' in result or result['count'] == 0 def test_invalid_date_column(self): """测试无效的日期列。""" tool = GetTimeSeriesTool() df = pd.DataFrame({ 'not_date': ['a', 'b', 'c'] }) result = tool.execute(df, time_column='not_date') assert 'error' in result class TestToolRegistry: """测试工具注册表。""" def test_register_and_retrieve(self): """测试注册和检索工具。""" registry = ToolRegistry() tool = GetColumnDistributionTool() registry.register(tool) retrieved = registry.get_tool(tool.name) assert retrieved.name == tool.name def test_unregister(self): """测试注销工具。""" registry = ToolRegistry() tool = GetColumnDistributionTool() registry.register(tool) registry.unregister(tool.name) with pytest.raises(KeyError): registry.get_tool(tool.name) def test_list_tools(self): """测试列出所有工具。""" registry = ToolRegistry() tool1 = GetColumnDistributionTool() tool2 = GetValueCountsTool() registry.register(tool1) registry.register(tool2) tools = registry.list_tools() assert len(tools) == 2 assert tool1.name in tools assert tool2.name in tools def test_get_applicable_tools(self): """测试获取适用的工具。""" registry = ToolRegistry() # 注册所有工具 registry.register(GetColumnDistributionTool()) registry.register(CalculateStatisticsTool()) registry.register(GetTimeSeriesTool()) # 创建包含数值和时间列的数据画像 profile = DataProfile( file_path='test.csv', row_count=100, column_count=2, columns=[ ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50), ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100) ], inferred_type='unknown' ) applicable = registry.get_applicable_tools(profile) # 所有工具都应该适用(GetColumnDistributionTool 适用于所有数据) assert len(applicable) > 0 class TestToolManager: """测试工具管理器。""" def test_select_tools_for_datetime_data(self): """测试为包含时间字段的数据选择工具。""" from src.tools.tool_manager import ToolManager # 创建工具注册表并注册所有工具 registry = ToolRegistry() registry.register(GetTimeSeriesTool()) registry.register(CalculateTrendTool()) registry.register(GetColumnDistributionTool()) manager = ToolManager(registry) # 创建包含时间字段的数据画像 profile = DataProfile( file_path='test.csv', row_count=100, column_count=1, columns=[ ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100) ], inferred_type='unknown', key_fields={}, quality_score=100.0, summary='Test data' ) tools = manager.select_tools(profile) tool_names = [tool.name for tool in tools] # 应该包含时间序列工具 assert 'get_time_series' in tool_names assert 'calculate_trend' in tool_names def test_select_tools_for_numeric_data(self): """测试为包含数值字段的数据选择工具。""" from src.tools.tool_manager import ToolManager registry = ToolRegistry() registry.register(CalculateStatisticsTool()) registry.register(DetectOutliersTool()) registry.register(GetCorrelationTool()) manager = ToolManager(registry) # 创建包含数值字段的数据画像 profile = DataProfile( file_path='test.csv', row_count=100, column_count=2, columns=[ ColumnInfo(name='value1', dtype='numeric', missing_rate=0.0, unique_count=50), ColumnInfo(name='value2', dtype='numeric', missing_rate=0.0, unique_count=50) ], inferred_type='unknown', key_fields={}, quality_score=100.0, summary='Test data' ) tools = manager.select_tools(profile) tool_names = [tool.name for tool in tools] # 应该包含统计工具 assert 'calculate_statistics' in tool_names assert 'detect_outliers' in tool_names assert 'get_correlation' in tool_names def test_select_tools_for_categorical_data(self): """测试为包含分类字段的数据选择工具。""" from src.tools.tool_manager import ToolManager registry = ToolRegistry() registry.register(GetColumnDistributionTool()) registry.register(GetValueCountsTool()) registry.register(PerformGroupbyTool()) manager = ToolManager(registry) # 创建包含分类字段的数据画像 profile = DataProfile( file_path='test.csv', row_count=100, column_count=1, columns=[ ColumnInfo(name='category', dtype='categorical', missing_rate=0.0, unique_count=5) ], inferred_type='unknown', key_fields={}, quality_score=100.0, summary='Test data' ) tools = manager.select_tools(profile) tool_names = [tool.name for tool in tools] # 应该包含分类工具 assert 'get_column_distribution' in tool_names assert 'get_value_counts' in tool_names assert 'perform_groupby' in tool_names def test_no_geo_tools_for_non_geo_data(self): """测试不为非地理数据选择地理工具。""" from src.tools.tool_manager import ToolManager registry = ToolRegistry() registry.register(GetColumnDistributionTool()) manager = ToolManager(registry) # 创建不包含地理字段的数据画像 profile = DataProfile( file_path='test.csv', row_count=100, column_count=1, columns=[ ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50) ], inferred_type='unknown', key_fields={}, quality_score=100.0, summary='Test data' ) tools = manager.select_tools(profile) tool_names = [tool.name for tool in tools] # 不应该包含地理工具 assert 'create_map_visualization' not in tool_names def test_identify_missing_tools(self): """测试识别缺失的工具。""" from src.tools.tool_manager import ToolManager # 创建空的工具注册表 empty_registry = ToolRegistry() manager = ToolManager(empty_registry) # 创建包含时间字段的数据画像 profile = DataProfile( file_path='test.csv', row_count=100, column_count=1, columns=[ ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100) ], inferred_type='unknown', key_fields={}, quality_score=100.0, summary='Test data' ) # 尝试选择工具 tools = manager.select_tools(profile) # 获取缺失的工具 missing = manager.get_missing_tools() # 应该识别出缺失的时间序列工具 assert len(missing) > 0 assert any(tool in missing for tool in ['get_time_series', 'calculate_trend']) def test_clear_missing_tools(self): """测试清空缺失工具列表。""" from src.tools.tool_manager import ToolManager empty_registry = ToolRegistry() manager = ToolManager(empty_registry) # 创建数据画像并选择工具(会记录缺失工具) profile = DataProfile( file_path='test.csv', row_count=100, column_count=1, columns=[ ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100) ], inferred_type='unknown', key_fields={}, quality_score=100.0, summary='Test data' ) manager.select_tools(profile) assert len(manager.get_missing_tools()) > 0 # 清空缺失工具列表 manager.clear_missing_tools() assert len(manager.get_missing_tools()) == 0 def test_get_tool_descriptions(self): """测试获取工具描述。""" from src.tools.tool_manager import ToolManager registry = ToolRegistry() tool1 = GetColumnDistributionTool() tool2 = CalculateStatisticsTool() registry.register(tool1) registry.register(tool2) manager = ToolManager(registry) tools = [tool1, tool2] descriptions = manager.get_tool_descriptions(tools) assert len(descriptions) == 2 assert all('name' in desc for desc in descriptions) assert all('description' in desc for desc in descriptions) assert all('parameters' in desc for desc in descriptions) def test_tool_deduplication(self): """测试工具去重。""" from src.tools.tool_manager import ToolManager registry = ToolRegistry() # 注册一个工具,它可能被多个类别选中 tool = GetColumnDistributionTool() registry.register(tool) manager = ToolManager(registry) # 创建包含多种类型字段的数据画像 profile = DataProfile( file_path='test.csv', row_count=100, column_count=2, columns=[ ColumnInfo(name='category', dtype='categorical', missing_rate=0.0, unique_count=5), ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50) ], inferred_type='unknown', key_fields={}, quality_score=100.0, summary='Test data' ) tools = manager.select_tools(profile) tool_names = [tool.name for tool in tools] # 工具名称应该是唯一的(没有重复) assert len(tool_names) == len(set(tool_names))