# Source file: vibe_data_ana/tests/test_tools.py
# (681 lines, 22 KiB, Python)

"""工具系统的单元测试。"""
import pytest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from src.tools.base import AnalysisTool, ToolRegistry
from src.tools.query_tools import (
GetColumnDistributionTool,
GetValueCountsTool,
GetTimeSeriesTool,
GetCorrelationTool
)
from src.tools.stats_tools import (
CalculateStatisticsTool,
PerformGroupbyTool,
DetectOutliersTool,
CalculateTrendTool
)
from src.models import DataProfile, ColumnInfo
class TestGetColumnDistributionTool:
    """Unit tests for the column-distribution query tool."""

    def test_basic_functionality(self):
        """A three-category column reports its counts and distribution."""
        statuses = ['open', 'closed', 'open', 'pending', 'closed', 'open']
        frame = pd.DataFrame({'status': statuses})

        outcome = GetColumnDistributionTool().execute(frame, column='status')

        assert 'distribution' in outcome
        assert outcome['column'] == 'status'
        assert outcome['total_count'] == 6
        assert outcome['unique_count'] == 3
        assert len(outcome['distribution']) == 3

    def test_top_n_limit(self):
        """Passing ``top_n=5`` yields exactly five distribution entries."""
        frame = pd.DataFrame({'value': list(range(20))})

        outcome = GetColumnDistributionTool().execute(
            frame, column='value', top_n=5
        )

        assert len(outcome['distribution']) == 5

    def test_nonexistent_column(self):
        """Asking for a column that is not in the frame reports an error."""
        frame = pd.DataFrame({'col1': [1, 2, 3]})

        outcome = GetColumnDistributionTool().execute(
            frame, column='nonexistent'
        )

        assert 'error' in outcome
class TestGetValueCountsTool:
    """Unit tests for the value-counts query tool."""

    def test_basic_functionality(self):
        """Each category is reported with its occurrence count."""
        labels = ['A', 'B', 'A', 'C', 'B', 'A']
        frame = pd.DataFrame({'category': labels})

        outcome = GetValueCountsTool().execute(frame, column='category')

        assert 'value_counts' in outcome
        assert outcome['value_counts']['A'] == 3
        assert outcome['value_counts']['B'] == 2
        assert outcome['value_counts']['C'] == 1

    def test_normalized_counts(self):
        """With ``normalize=True`` two equally frequent values each get ~0.5."""
        frame = pd.DataFrame({'category': ['A', 'A', 'B', 'B']})

        outcome = GetValueCountsTool().execute(
            frame, column='category', normalize=True
        )

        assert outcome['normalized'] is True
        assert abs(outcome['value_counts']['A'] - 0.5) < 0.01
        assert abs(outcome['value_counts']['B'] - 0.5) < 0.01
class TestGetTimeSeriesTool:
    """Unit tests for the time-series query tool."""

    def test_basic_functionality(self):
        """Summing a value column over daily timestamps returns a series."""
        days = pd.date_range('2020-01-01', periods=10, freq='D')
        frame = pd.DataFrame({'date': days, 'value': range(10)})

        report = GetTimeSeriesTool().execute(
            frame,
            time_column='date',
            value_column='value',
            aggregation='sum',
        )

        assert 'time_series' in report
        assert report['time_column'] == 'date'
        assert report['aggregation'] == 'sum'
        assert len(report['time_series']) > 0

    def test_count_aggregation(self):
        """The ``count`` aggregation works without a value column."""
        days = pd.date_range('2020-01-01', periods=5, freq='D')
        frame = pd.DataFrame({'date': days})

        report = GetTimeSeriesTool().execute(
            frame, time_column='date', aggregation='count'
        )

        assert 'time_series' in report
        assert len(report['time_series']) > 0

    def test_output_limit(self):
        """200 input points are reported as 200 total and 100 returned."""
        days = pd.date_range('2020-01-01', periods=200, freq='D')
        frame = pd.DataFrame({'date': days})

        report = GetTimeSeriesTool().execute(frame, time_column='date')

        assert len(report['time_series']) <= 100
        assert report['total_points'] == 200
        assert report['returned_points'] == 100
class TestGetCorrelationTool:
    """Unit tests for the correlation-analysis tool."""

    def test_basic_functionality(self):
        """The correlation matrix covers x and y and shows them as ~1.0."""
        columns = {
            'x': [1, 2, 3, 4, 5],
            'y': [2, 4, 6, 8, 10],
            'z': [1, 1, 1, 1, 1],
        }
        frame = pd.DataFrame(columns)

        report = GetCorrelationTool().execute(frame)

        assert 'correlation_matrix' in report
        matrix = report['correlation_matrix']
        assert 'x' in matrix
        assert 'y' in matrix
        # y is exactly 2 * x, so the two are perfectly positively correlated.
        assert abs(matrix['x']['y'] - 1.0) < 0.01

    def test_insufficient_numeric_columns(self):
        """A frame with only one numeric column produces an error."""
        frame = pd.DataFrame({'x': [1, 2, 3], 'text': ['a', 'b', 'c']})

        report = GetCorrelationTool().execute(frame)

        assert 'error' in report
class TestCalculateStatisticsTool:
    """Unit tests for the descriptive-statistics tool."""

    def test_basic_functionality(self):
        """Statistics for the integers 1..10 match the known values."""
        sample = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        frame = pd.DataFrame({'values': sample})

        stats = CalculateStatisticsTool().execute(frame, column='values')

        assert stats['mean'] == 5.5
        assert stats['median'] == 5.5
        assert stats['min'] == 1
        assert stats['max'] == 10
        assert stats['count'] == 10

    def test_non_numeric_column(self):
        """Requesting statistics for a text column reports an error."""
        frame = pd.DataFrame({'text': ['a', 'b', 'c']})

        stats = CalculateStatisticsTool().execute(frame, column='text')

        assert 'error' in stats
class TestPerformGroupbyTool:
    """Unit tests for the group-by aggregation tool."""

    def test_basic_functionality(self):
        """Summing values per category yields one entry per group."""
        frame = pd.DataFrame(
            {
                'category': ['A', 'B', 'A', 'B', 'A'],
                'value': [10, 20, 30, 40, 50],
            }
        )

        report = PerformGroupbyTool().execute(
            frame,
            group_by='category',
            value_column='value',
            aggregation='sum',
        )

        assert 'groups' in report
        assert len(report['groups']) == 2
        # Locate the entry for group A and check its total.
        entry_a = next(
            entry for entry in report['groups'] if entry['group'] == 'A'
        )
        assert entry_a['value'] == 90  # 10 + 30 + 50

    def test_count_aggregation(self):
        """Without a value column or aggregation, group sizes are returned."""
        frame = pd.DataFrame({'category': ['A', 'B', 'A', 'B', 'A']})

        report = PerformGroupbyTool().execute(frame, group_by='category')

        assert len(report['groups']) == 2
        entry_a = next(
            entry for entry in report['groups'] if entry['group'] == 'A'
        )
        assert entry_a['value'] == 3

    def test_output_limit(self):
        """200 distinct groups are reported as 200 total and 100 returned."""
        frame = pd.DataFrame(
            {
                'category': [f'cat_{i}' for i in range(200)],
                'value': range(200),
            }
        )

        report = PerformGroupbyTool().execute(
            frame,
            group_by='category',
            value_column='value',
            aggregation='sum',
        )

        assert len(report['groups']) <= 100
        assert report['total_groups'] == 200
        assert report['returned_groups'] == 100
class TestDetectOutliersTool:
    """Unit tests for the outlier-detection tool."""

    def test_iqr_method(self):
        """The IQR method flags the value 100 among the integers 1..9."""
        # 100 is an obvious outlier relative to the rest of the sample.
        sample = [1, 2, 3, 4, 5, 6, 7, 8, 9, 100]
        frame = pd.DataFrame({'values': sample})

        report = DetectOutliersTool().execute(
            frame, column='values', method='iqr'
        )

        assert report['outlier_count'] > 0
        assert 100 in report['outlier_values']

    def test_zscore_method(self):
        """The z-score method with threshold 2 finds at least one outlier."""
        sample = [1, 2, 3, 4, 5, 6, 7, 8, 9, 100]
        frame = pd.DataFrame({'values': sample})

        report = DetectOutliersTool().execute(
            frame, column='values', method='zscore', threshold=2
        )

        assert report['outlier_count'] > 0
        assert report['method'] == 'zscore'
class TestCalculateTrendTool:
    """Unit tests for the trend-calculation tool."""

    def test_increasing_trend(self):
        """A steadily rising daily series is classified as increasing."""
        days = pd.date_range('2020-01-01', periods=10, freq='D')
        frame = pd.DataFrame({'date': days, 'value': range(10)})

        report = CalculateTrendTool().execute(
            frame, time_column='date', value_column='value'
        )

        assert report['trend'] == 'increasing'
        assert report['slope'] > 0
        # The input is perfectly linear, so the fit should be near-perfect.
        assert report['r_squared'] > 0.9

    def test_decreasing_trend(self):
        """A steadily falling daily series is classified as decreasing."""
        days = pd.date_range('2020-01-01', periods=10, freq='D')
        falling = list(range(10, 0, -1))
        frame = pd.DataFrame({'date': days, 'value': falling})

        report = CalculateTrendTool().execute(
            frame, time_column='date', value_column='value'
        )

        assert report['trend'] == 'decreasing'
        assert report['slope'] < 0
class TestToolParameterValidation:
    """Tests for how tools react to missing or invalid call parameters."""

    def test_missing_required_parameter(self):
        """Omitting the required ``column`` argument is reported, not ignored.

        The tool is accepted as correct when it returns either ``None`` or a
        result containing an ``'error'`` entry.
        """
        tool = GetColumnDistributionTool()
        df = pd.DataFrame({'col': [1, 2, 3]})
        # Deliberately leave out the required ``column`` argument.
        result = tool.execute(df)
        # Test for None first: ``'error' in None`` raises TypeError, which
        # made the ``result is None`` alternative unreachable when it was
        # written as the second operand.
        assert result is None or 'error' in result

    def test_invalid_aggregation_method(self):
        """An unrecognised aggregation name yields an error result."""
        tool = PerformGroupbyTool()
        df = pd.DataFrame({
            'category': ['A', 'B'],
            'value': [1, 2],
        })
        result = tool.execute(
            df,
            group_by='category',
            value_column='value',
            aggregation='invalid',
        )
        assert 'error' in result
class TestToolErrorHandling:
    """Tests for tool behaviour on degenerate or malformed input."""

    def test_empty_dataframe(self):
        """An empty frame (and therefore a missing column) yields an error."""
        frame = pd.DataFrame()

        outcome = CalculateStatisticsTool().execute(
            frame, column='nonexistent'
        )

        assert 'error' in outcome

    def test_all_null_values(self):
        """A column holding only nulls gives an error or a zero count."""
        frame = pd.DataFrame({'values': [None, None, None]})

        outcome = CalculateStatisticsTool().execute(frame, column='values')

        # Either way of signalling "no usable data" is accepted.
        assert 'error' in outcome or outcome['count'] == 0

    def test_invalid_date_column(self):
        """Using a non-date text column as the time column yields an error."""
        frame = pd.DataFrame({'not_date': ['a', 'b', 'c']})

        outcome = GetTimeSeriesTool().execute(frame, time_column='not_date')

        assert 'error' in outcome
class TestToolRegistry:
    """Tests for registering, removing and querying tools in ToolRegistry."""

    def test_register_and_retrieve(self):
        """A registered tool can be fetched back by its name."""
        registry = ToolRegistry()
        tool = GetColumnDistributionTool()
        registry.register(tool)
        retrieved = registry.get_tool(tool.name)
        assert retrieved.name == tool.name

    def test_unregister(self):
        """Looking up a tool after unregistering it raises KeyError."""
        registry = ToolRegistry()
        tool = GetColumnDistributionTool()
        registry.register(tool)
        registry.unregister(tool.name)
        with pytest.raises(KeyError):
            registry.get_tool(tool.name)

    def test_list_tools(self):
        """``list_tools`` contains the name of every registered tool."""
        registry = ToolRegistry()
        tool1 = GetColumnDistributionTool()
        tool2 = GetValueCountsTool()
        registry.register(tool1)
        registry.register(tool2)
        tools = registry.list_tools()
        assert len(tools) == 2
        assert tool1.name in tools
        assert tool2.name in tools

    def test_get_applicable_tools(self):
        """At least one registered tool applies to a numeric + datetime profile."""
        registry = ToolRegistry()
        registry.register(GetColumnDistributionTool())
        registry.register(CalculateStatisticsTool())
        registry.register(GetTimeSeriesTool())
        # Profile with one numeric and one datetime column. The profile is
        # built with the same full set of fields as every other DataProfile
        # in this module, so the test does not depend on DataProfile
        # providing defaults for key_fields / quality_score / summary.
        profile = DataProfile(
            file_path='test.csv',
            row_count=100,
            column_count=2,
            columns=[
                ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50),
                ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100),
            ],
            inferred_type='unknown',
            key_fields={},
            quality_score=100.0,
            summary='Test data',
        )
        applicable = registry.get_applicable_tools(profile)
        # Per the original note, GetColumnDistributionTool is expected to
        # apply to any data, so the result must not be empty.
        assert len(applicable) > 0
class TestToolManager:
    """Tests for ToolManager tool selection and missing-tool tracking."""

    @staticmethod
    def _make_manager(*tools):
        """Return a ToolManager over a fresh registry holding ``tools``.

        Tools are registered in the order given. ToolManager is imported
        inside the function, as the original tests did, rather than at
        module import time.
        """
        from src.tools.tool_manager import ToolManager

        registry = ToolRegistry()
        for tool in tools:
            registry.register(tool)
        return ToolManager(registry)

    @staticmethod
    def _column(name, dtype, unique_count):
        """Return a ColumnInfo with no missing values."""
        return ColumnInfo(
            name=name,
            dtype=dtype,
            missing_rate=0.0,
            unique_count=unique_count,
        )

    @staticmethod
    def _make_profile(*columns):
        """Return the 100-row test DataProfile made of ``columns``."""
        return DataProfile(
            file_path='test.csv',
            row_count=100,
            column_count=len(columns),
            columns=list(columns),
            inferred_type='unknown',
            key_fields={},
            quality_score=100.0,
            summary='Test data',
        )

    def test_select_tools_for_datetime_data(self):
        """A datetime column selects the time-series and trend tools."""
        manager = self._make_manager(
            GetTimeSeriesTool(),
            CalculateTrendTool(),
            GetColumnDistributionTool(),
        )
        profile = self._make_profile(self._column('date', 'datetime', 100))

        tool_names = [tool.name for tool in manager.select_tools(profile)]

        assert 'get_time_series' in tool_names
        assert 'calculate_trend' in tool_names

    def test_select_tools_for_numeric_data(self):
        """Numeric columns select statistics, outlier and correlation tools."""
        manager = self._make_manager(
            CalculateStatisticsTool(),
            DetectOutliersTool(),
            GetCorrelationTool(),
        )
        profile = self._make_profile(
            self._column('value1', 'numeric', 50),
            self._column('value2', 'numeric', 50),
        )

        tool_names = [tool.name for tool in manager.select_tools(profile)]

        assert 'calculate_statistics' in tool_names
        assert 'detect_outliers' in tool_names
        assert 'get_correlation' in tool_names

    def test_select_tools_for_categorical_data(self):
        """A categorical column selects distribution, counts and group-by."""
        manager = self._make_manager(
            GetColumnDistributionTool(),
            GetValueCountsTool(),
            PerformGroupbyTool(),
        )
        profile = self._make_profile(
            self._column('category', 'categorical', 5)
        )

        tool_names = [tool.name for tool in manager.select_tools(profile)]

        assert 'get_column_distribution' in tool_names
        assert 'get_value_counts' in tool_names
        assert 'perform_groupby' in tool_names

    def test_no_geo_tools_for_non_geo_data(self):
        """No map-visualisation tool is selected for a non-geographic profile."""
        manager = self._make_manager(GetColumnDistributionTool())
        profile = self._make_profile(self._column('value', 'numeric', 50))

        tool_names = [tool.name for tool in manager.select_tools(profile)]

        assert 'create_map_visualization' not in tool_names

    def test_identify_missing_tools(self):
        """Selecting against an empty registry records the missing tools."""
        manager = self._make_manager()
        profile = self._make_profile(self._column('date', 'datetime', 100))

        # Called only for its side effect of recording missing tools; the
        # returned selection is not needed here.
        manager.select_tools(profile)
        missing = manager.get_missing_tools()

        assert len(missing) > 0
        assert any(
            tool in missing for tool in ['get_time_series', 'calculate_trend']
        )

    def test_clear_missing_tools(self):
        """``clear_missing_tools`` empties the recorded missing-tool list."""
        manager = self._make_manager()
        profile = self._make_profile(self._column('date', 'datetime', 100))

        # Selecting against an empty registry records missing tools.
        manager.select_tools(profile)
        assert len(manager.get_missing_tools()) > 0

        manager.clear_missing_tools()
        assert len(manager.get_missing_tools()) == 0

    def test_get_tool_descriptions(self):
        """Each description carries name, description and parameters keys."""
        tool1 = GetColumnDistributionTool()
        tool2 = CalculateStatisticsTool()
        manager = self._make_manager(tool1, tool2)

        descriptions = manager.get_tool_descriptions([tool1, tool2])

        assert len(descriptions) == 2
        assert all('name' in desc for desc in descriptions)
        assert all('description' in desc for desc in descriptions)
        assert all('parameters' in desc for desc in descriptions)

    def test_tool_deduplication(self):
        """A tool is returned only once for a profile with mixed column types."""
        manager = self._make_manager(GetColumnDistributionTool())
        profile = self._make_profile(
            self._column('category', 'categorical', 5),
            self._column('value', 'numeric', 50),
        )

        tool_names = [tool.name for tool in manager.select_tools(profile)]

        # Names must be unique, i.e. no tool appears twice.
        assert len(tool_names) == len(set(tool_names))