681 lines
22 KiB
Python
681 lines
22 KiB
Python
"""工具系统的单元测试。"""
|
||
|
||
import pytest
|
||
import pandas as pd
|
||
import numpy as np
|
||
from datetime import datetime, timedelta
|
||
|
||
from src.tools.base import AnalysisTool, ToolRegistry
|
||
from src.tools.query_tools import (
|
||
GetColumnDistributionTool,
|
||
GetValueCountsTool,
|
||
GetTimeSeriesTool,
|
||
GetCorrelationTool
|
||
)
|
||
from src.tools.stats_tools import (
|
||
CalculateStatisticsTool,
|
||
PerformGroupbyTool,
|
||
DetectOutliersTool,
|
||
CalculateTrendTool
|
||
)
|
||
from src.models import DataProfile, ColumnInfo
|
||
|
||
|
||
class TestGetColumnDistributionTool:
    """Unit tests for the column distribution tool."""

    def test_basic_functionality(self):
        """A column with three distinct values yields a three-entry distribution."""
        statuses = ['open', 'closed', 'open', 'pending', 'closed', 'open']
        distribution_tool = GetColumnDistributionTool()
        frame = pd.DataFrame({'status': statuses})

        outcome = distribution_tool.execute(frame, column='status')

        assert 'distribution' in outcome
        assert outcome['column'] == 'status'
        assert outcome['total_count'] == 6
        assert outcome['unique_count'] == 3
        assert len(outcome['distribution']) == 3

    def test_top_n_limit(self):
        """The ``top_n`` argument caps the number of returned entries."""
        distribution_tool = GetColumnDistributionTool()
        frame = pd.DataFrame({'value': list(range(20))})

        outcome = distribution_tool.execute(frame, column='value', top_n=5)

        assert len(outcome['distribution']) == 5

    def test_nonexistent_column(self):
        """Requesting a column that is not in the frame reports an error."""
        distribution_tool = GetColumnDistributionTool()
        frame = pd.DataFrame({'col1': [1, 2, 3]})

        outcome = distribution_tool.execute(frame, column='nonexistent')

        assert 'error' in outcome
|
||
|
||
|
||
class TestGetValueCountsTool:
    """Unit tests for the value counts tool."""

    def test_basic_functionality(self):
        """Raw counts are reported per distinct value."""
        counter_tool = GetValueCountsTool()
        frame = pd.DataFrame({'category': ['A', 'B', 'A', 'C', 'B', 'A']})

        outcome = counter_tool.execute(frame, column='category')

        assert 'value_counts' in outcome
        # Checked in the same A, B, C order so the first failure is the same.
        for label, expected in (('A', 3), ('B', 2), ('C', 1)):
            assert outcome['value_counts'][label] == expected

    def test_normalized_counts(self):
        """With ``normalize=True`` the counts become fractions of the total."""
        counter_tool = GetValueCountsTool()
        frame = pd.DataFrame({'category': ['A', 'A', 'B', 'B']})

        outcome = counter_tool.execute(frame, column='category', normalize=True)

        assert outcome['normalized'] is True
        # Two of four rows each, so both shares should be within 0.01 of 0.5.
        for label in ('A', 'B'):
            assert abs(outcome['value_counts'][label] - 0.5) < 0.01
|
||
|
||
|
||
class TestGetTimeSeriesTool:
    """Unit tests for the time series tool."""

    def test_basic_functionality(self):
        """Summing a value column over ten daily timestamps returns a series."""
        series_tool = GetTimeSeriesTool()
        days = pd.date_range('2020-01-01', periods=10, freq='D')
        frame = pd.DataFrame({'date': days, 'value': range(10)})

        outcome = series_tool.execute(
            frame,
            time_column='date',
            value_column='value',
            aggregation='sum',
        )

        assert 'time_series' in outcome
        assert outcome['time_column'] == 'date'
        assert outcome['aggregation'] == 'sum'
        assert len(outcome['time_series']) > 0

    def test_count_aggregation(self):
        """Count aggregation works with only a time column present."""
        series_tool = GetTimeSeriesTool()
        days = pd.date_range('2020-01-01', periods=5, freq='D')
        frame = pd.DataFrame({'date': days})

        outcome = series_tool.execute(frame, time_column='date', aggregation='count')

        assert 'time_series' in outcome
        assert len(outcome['time_series']) > 0

    def test_output_limit(self):
        """Output is capped at 100 rows while the full point count is still reported."""
        series_tool = GetTimeSeriesTool()
        days = pd.date_range('2020-01-01', periods=200, freq='D')
        frame = pd.DataFrame({'date': days})

        outcome = series_tool.execute(frame, time_column='date')

        assert len(outcome['time_series']) <= 100
        assert outcome['total_points'] == 200
        assert outcome['returned_points'] == 100
|
||
|
||
|
||
class TestGetCorrelationTool:
    """Unit tests for the correlation analysis tool."""

    def test_basic_functionality(self):
        """Two perfectly linear columns show a correlation of about 1.0."""
        correlation_tool = GetCorrelationTool()
        columns = {
            'x': [1, 2, 3, 4, 5],
            'y': [2, 4, 6, 8, 10],
            'z': [1, 1, 1, 1, 1],
        }
        frame = pd.DataFrame(columns)

        outcome = correlation_tool.execute(frame)

        assert 'correlation_matrix' in outcome
        matrix = outcome['correlation_matrix']
        assert 'x' in matrix
        assert 'y' in matrix
        # y is exactly 2 * x, so x and y are perfectly positively correlated.
        assert abs(matrix['x']['y'] - 1.0) < 0.01

    def test_insufficient_numeric_columns(self):
        """A frame with a single numeric column cannot be correlated."""
        correlation_tool = GetCorrelationTool()
        frame = pd.DataFrame({'x': [1, 2, 3], 'text': ['a', 'b', 'c']})

        outcome = correlation_tool.execute(frame)

        assert 'error' in outcome
|
||
|
||
|
||
class TestCalculateStatisticsTool:
    """Unit tests for the descriptive statistics tool."""

    def test_basic_functionality(self):
        """Statistics of the integers 1..10 match their known values."""
        stats_tool = CalculateStatisticsTool()
        frame = pd.DataFrame({'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]})

        outcome = stats_tool.execute(frame, column='values')

        # Kept in the original assertion order: mean, median, min, max, count.
        expectations = (
            ('mean', 5.5),
            ('median', 5.5),
            ('min', 1),
            ('max', 10),
            ('count', 10),
        )
        for key, expected in expectations:
            assert outcome[key] == expected

    def test_non_numeric_column(self):
        """A text column is rejected with an error."""
        stats_tool = CalculateStatisticsTool()
        frame = pd.DataFrame({'text': ['a', 'b', 'c']})

        outcome = stats_tool.execute(frame, column='text')

        assert 'error' in outcome
|
||
|
||
|
||
class TestPerformGroupbyTool:
    """Unit tests for the group-by aggregation tool."""

    @staticmethod
    def _group_named(groups, name):
        """Return the first entry of ``groups`` whose 'group' field equals ``name``."""
        return next(entry for entry in groups if entry['group'] == name)

    def test_basic_functionality(self):
        """Summing values per category produces one entry per category."""
        groupby_tool = PerformGroupbyTool()
        frame = pd.DataFrame({
            'category': ['A', 'B', 'A', 'B', 'A'],
            'value': [10, 20, 30, 40, 50],
        })

        outcome = groupby_tool.execute(
            frame,
            group_by='category',
            value_column='value',
            aggregation='sum',
        )

        assert 'groups' in outcome
        assert len(outcome['groups']) == 2
        # Group A holds 10 + 30 + 50.
        a_entry = self._group_named(outcome['groups'], 'A')
        assert a_entry['value'] == 90

    def test_count_aggregation(self):
        """Without a value column or aggregation, group A reports its three rows."""
        groupby_tool = PerformGroupbyTool()
        frame = pd.DataFrame({'category': ['A', 'B', 'A', 'B', 'A']})

        outcome = groupby_tool.execute(frame, group_by='category')

        assert len(outcome['groups']) == 2
        a_entry = self._group_named(outcome['groups'], 'A')
        assert a_entry['value'] == 3

    def test_output_limit(self):
        """Output is capped at 100 groups while the full group count is still reported."""
        groupby_tool = PerformGroupbyTool()
        frame = pd.DataFrame({
            'category': [f'cat_{i}' for i in range(200)],
            'value': range(200),
        })

        outcome = groupby_tool.execute(
            frame,
            group_by='category',
            value_column='value',
            aggregation='sum',
        )

        assert len(outcome['groups']) <= 100
        assert outcome['total_groups'] == 200
        assert outcome['returned_groups'] == 100
|
||
|
||
|
||
class TestDetectOutliersTool:
    """Unit tests for the outlier detection tool."""

    def test_iqr_method(self):
        """The IQR method flags the value 100 among the integers 1..9."""
        outlier_tool = DetectOutliersTool()
        # 100 is an obvious outlier next to the values 1..9.
        samples = [1, 2, 3, 4, 5, 6, 7, 8, 9, 100]
        frame = pd.DataFrame({'values': samples})

        outcome = outlier_tool.execute(frame, column='values', method='iqr')

        assert outcome['outlier_count'] > 0
        assert 100 in outcome['outlier_values']

    def test_zscore_method(self):
        """The z-score method with threshold 2 finds at least one outlier."""
        outlier_tool = DetectOutliersTool()
        samples = [1, 2, 3, 4, 5, 6, 7, 8, 9, 100]
        frame = pd.DataFrame({'values': samples})

        outcome = outlier_tool.execute(
            frame,
            column='values',
            method='zscore',
            threshold=2,
        )

        assert outcome['outlier_count'] > 0
        assert outcome['method'] == 'zscore'
|
||
|
||
|
||
class TestCalculateTrendTool:
    """Unit tests for the trend calculation tool."""

    @staticmethod
    def _ten_day_frame(values):
        """Pair ``values`` with ten consecutive daily timestamps from 2020-01-01."""
        days = pd.date_range('2020-01-01', periods=10, freq='D')
        return pd.DataFrame({'date': days, 'value': values})

    def test_increasing_trend(self):
        """Steadily rising values are classified as an increasing trend."""
        trend_tool = CalculateTrendTool()
        frame = self._ten_day_frame(range(10))

        outcome = trend_tool.execute(frame, time_column='date', value_column='value')

        assert outcome['trend'] == 'increasing'
        assert outcome['slope'] > 0
        # The input is perfectly linear, so the fit should be near-exact.
        assert outcome['r_squared'] > 0.9

    def test_decreasing_trend(self):
        """Steadily falling values are classified as a decreasing trend."""
        trend_tool = CalculateTrendTool()
        frame = self._ten_day_frame(list(range(10, 0, -1)))

        outcome = trend_tool.execute(frame, time_column='date', value_column='value')

        assert outcome['trend'] == 'decreasing'
        assert outcome['slope'] < 0
|
||
|
||
|
||
class TestToolParameterValidation:
    """Tests for how tools react to missing or invalid parameters."""

    def test_missing_required_parameter(self):
        """Omitting the required ``column`` argument yields None or an error result."""
        tool = GetColumnDistributionTool()
        df = pd.DataFrame({'col': [1, 2, 3]})

        # Deliberately leave out the required ``column`` argument.
        result = tool.execute(df)

        # Test for None first: ``'error' in None`` raises TypeError, so with
        # the operands the other way round a None result could never pass.
        assert result is None or 'error' in result

    def test_invalid_aggregation_method(self):
        """An unknown aggregation name is reported as an error."""
        tool = PerformGroupbyTool()
        df = pd.DataFrame({
            'category': ['A', 'B'],
            'value': [1, 2],
        })

        result = tool.execute(
            df,
            group_by='category',
            value_column='value',
            aggregation='invalid',
        )

        assert 'error' in result
|
||
|
||
|
||
class TestToolErrorHandling:
    """Tests for tool behaviour on degenerate or malformed input."""

    def test_empty_dataframe(self):
        """Asking for statistics on an empty frame reports an error."""
        stats_tool = CalculateStatisticsTool()
        empty_frame = pd.DataFrame()

        outcome = stats_tool.execute(empty_frame, column='nonexistent')

        assert 'error' in outcome

    def test_all_null_values(self):
        """A column of only nulls gives either an error or a zero count."""
        stats_tool = CalculateStatisticsTool()
        frame = pd.DataFrame({'values': [None, None, None]})

        outcome = stats_tool.execute(frame, column='values')

        # Either way of handling the all-null case is accepted.
        assert 'error' in outcome or outcome['count'] == 0

    def test_invalid_date_column(self):
        """A text column used as the time column reports an error."""
        series_tool = GetTimeSeriesTool()
        frame = pd.DataFrame({'not_date': ['a', 'b', 'c']})

        outcome = series_tool.execute(frame, time_column='not_date')

        assert 'error' in outcome
|
||
|
||
|
||
class TestToolRegistry:
    """Unit tests for the tool registry."""

    def test_register_and_retrieve(self):
        """A registered tool can be fetched back by its name."""
        reg = ToolRegistry()
        distribution_tool = GetColumnDistributionTool()

        reg.register(distribution_tool)
        fetched = reg.get_tool(distribution_tool.name)

        assert fetched.name == distribution_tool.name

    def test_unregister(self):
        """Looking up a tool after unregistering it raises KeyError."""
        reg = ToolRegistry()
        distribution_tool = GetColumnDistributionTool()

        reg.register(distribution_tool)
        reg.unregister(distribution_tool.name)

        with pytest.raises(KeyError):
            reg.get_tool(distribution_tool.name)

    def test_list_tools(self):
        """Listing returns every registered tool name."""
        reg = ToolRegistry()
        first = GetColumnDistributionTool()
        second = GetValueCountsTool()

        reg.register(first)
        reg.register(second)

        listed = reg.list_tools()
        assert len(listed) == 2
        assert first.name in listed
        assert second.name in listed

    def test_get_applicable_tools(self):
        """At least one registered tool applies to a numeric + datetime profile."""
        reg = ToolRegistry()

        # Register the tools under test.
        reg.register(GetColumnDistributionTool())
        reg.register(CalculateStatisticsTool())
        reg.register(GetTimeSeriesTool())

        # Profile describing one numeric and one datetime column.
        numeric_column = ColumnInfo(
            name='value', dtype='numeric', missing_rate=0.0, unique_count=50
        )
        datetime_column = ColumnInfo(
            name='date', dtype='datetime', missing_rate=0.0, unique_count=100
        )
        profile = DataProfile(
            file_path='test.csv',
            row_count=100,
            column_count=2,
            columns=[numeric_column, datetime_column],
            inferred_type='unknown',
        )

        matching = reg.get_applicable_tools(profile)

        # Only non-emptiness is asserted here.
        assert len(matching) > 0
|
||
|
||
|
||
|
||
class TestToolManager:
    """Unit tests for the tool manager."""

    @staticmethod
    def _manager_with(*tool_classes):
        """Return a ToolManager over a fresh registry holding one instance per class.

        The ToolManager import stays function-local, as it was in every
        original test body.
        """
        from src.tools.tool_manager import ToolManager

        registry = ToolRegistry()
        for tool_cls in tool_classes:
            registry.register(tool_cls())
        return ToolManager(registry)

    @staticmethod
    def _column(name, dtype, unique_count):
        """Build a ColumnInfo with no missing values."""
        return ColumnInfo(
            name=name, dtype=dtype, missing_rate=0.0, unique_count=unique_count
        )

    @staticmethod
    def _profile(*columns):
        """Build the fixed test DataProfile around the given columns."""
        return DataProfile(
            file_path='test.csv',
            row_count=100,
            column_count=len(columns),
            columns=list(columns),
            inferred_type='unknown',
            key_fields={},
            quality_score=100.0,
            summary='Test data',
        )

    @staticmethod
    def _selected_names(manager, profile):
        """Return the names of the tools the manager selects for ``profile``."""
        return [tool.name for tool in manager.select_tools(profile)]

    def test_select_tools_for_datetime_data(self):
        """A datetime column selects the time series and trend tools."""
        manager = self._manager_with(
            GetTimeSeriesTool,
            CalculateTrendTool,
            GetColumnDistributionTool,
        )
        profile = self._profile(self._column('date', 'datetime', 100))

        names = self._selected_names(manager, profile)

        assert 'get_time_series' in names
        assert 'calculate_trend' in names

    def test_select_tools_for_numeric_data(self):
        """Numeric columns select the statistics, outlier and correlation tools."""
        manager = self._manager_with(
            CalculateStatisticsTool,
            DetectOutliersTool,
            GetCorrelationTool,
        )
        profile = self._profile(
            self._column('value1', 'numeric', 50),
            self._column('value2', 'numeric', 50),
        )

        names = self._selected_names(manager, profile)

        assert 'calculate_statistics' in names
        assert 'detect_outliers' in names
        assert 'get_correlation' in names

    def test_select_tools_for_categorical_data(self):
        """A categorical column selects the distribution, counts and group-by tools."""
        manager = self._manager_with(
            GetColumnDistributionTool,
            GetValueCountsTool,
            PerformGroupbyTool,
        )
        profile = self._profile(self._column('category', 'categorical', 5))

        names = self._selected_names(manager, profile)

        assert 'get_column_distribution' in names
        assert 'get_value_counts' in names
        assert 'perform_groupby' in names

    def test_no_geo_tools_for_non_geo_data(self):
        """No map visualization tool is selected for data without geo fields."""
        manager = self._manager_with(GetColumnDistributionTool)
        profile = self._profile(self._column('value', 'numeric', 50))

        names = self._selected_names(manager, profile)

        assert 'create_map_visualization' not in names

    def test_identify_missing_tools(self):
        """Selecting against an empty registry records the missing time tools."""
        manager = self._manager_with()  # empty registry
        profile = self._profile(self._column('date', 'datetime', 100))

        # The selection result itself is not inspected by this test.
        manager.select_tools(profile)
        missing = manager.get_missing_tools()

        assert len(missing) > 0
        assert any(tool in missing for tool in ['get_time_series', 'calculate_trend'])

    def test_clear_missing_tools(self):
        """Clearing empties the list of recorded missing tools."""
        manager = self._manager_with()  # empty registry
        profile = self._profile(self._column('date', 'datetime', 100))

        # The list must be non-empty before clearing for the test to mean anything.
        manager.select_tools(profile)
        assert len(manager.get_missing_tools()) > 0

        manager.clear_missing_tools()
        assert len(manager.get_missing_tools()) == 0

    def test_get_tool_descriptions(self):
        """Each description carries a name, a description and parameters."""
        from src.tools.tool_manager import ToolManager

        registry = ToolRegistry()
        distribution_tool = GetColumnDistributionTool()
        stats_tool = CalculateStatisticsTool()
        registry.register(distribution_tool)
        registry.register(stats_tool)
        manager = ToolManager(registry)

        described = manager.get_tool_descriptions([distribution_tool, stats_tool])

        assert len(described) == 2
        # Same key order as the original three assertions.
        for required_key in ('name', 'description', 'parameters'):
            assert all(required_key in desc for desc in described)

    def test_tool_deduplication(self):
        """A tool is not returned twice for a mixed-type profile."""
        # One registered tool that may be picked for more than one column category.
        manager = self._manager_with(GetColumnDistributionTool)
        profile = self._profile(
            self._column('category', 'categorical', 5),
            self._column('value', 'numeric', 50),
        )

        names = self._selected_names(manager, profile)

        # Every selected tool name must be unique.
        assert len(names) == len(set(names))
|