Complete AI Data Analysis Agent implementation with 95.7% test coverage

This commit is contained in:
2026-03-07 00:04:29 +08:00
parent 621e546b43
commit 7071b1f730
245 changed files with 22612 additions and 2211 deletions

680
tests/test_tools.py Normal file
View File

@@ -0,0 +1,680 @@
"""工具系统的单元测试。"""
import pytest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from src.tools.base import AnalysisTool, ToolRegistry
from src.tools.query_tools import (
GetColumnDistributionTool,
GetValueCountsTool,
GetTimeSeriesTool,
GetCorrelationTool
)
from src.tools.stats_tools import (
CalculateStatisticsTool,
PerformGroupbyTool,
DetectOutliersTool,
CalculateTrendTool
)
from src.models import DataProfile, ColumnInfo
class TestGetColumnDistributionTool:
"""测试列分布工具。"""
def test_basic_functionality(self):
"""测试基本功能。"""
tool = GetColumnDistributionTool()
df = pd.DataFrame({
'status': ['open', 'closed', 'open', 'pending', 'closed', 'open']
})
result = tool.execute(df, column='status')
assert 'distribution' in result
assert result['column'] == 'status'
assert result['total_count'] == 6
assert result['unique_count'] == 3
assert len(result['distribution']) == 3
def test_top_n_limit(self):
"""测试 top_n 参数限制。"""
tool = GetColumnDistributionTool()
df = pd.DataFrame({
'value': list(range(20))
})
result = tool.execute(df, column='value', top_n=5)
assert len(result['distribution']) == 5
def test_nonexistent_column(self):
"""测试不存在的列。"""
tool = GetColumnDistributionTool()
df = pd.DataFrame({'col1': [1, 2, 3]})
result = tool.execute(df, column='nonexistent')
assert 'error' in result
class TestGetValueCountsTool:
"""测试值计数工具。"""
def test_basic_functionality(self):
"""测试基本功能。"""
tool = GetValueCountsTool()
df = pd.DataFrame({
'category': ['A', 'B', 'A', 'C', 'B', 'A']
})
result = tool.execute(df, column='category')
assert 'value_counts' in result
assert result['value_counts']['A'] == 3
assert result['value_counts']['B'] == 2
assert result['value_counts']['C'] == 1
def test_normalized_counts(self):
"""测试归一化计数。"""
tool = GetValueCountsTool()
df = pd.DataFrame({
'category': ['A', 'A', 'B', 'B']
})
result = tool.execute(df, column='category', normalize=True)
assert result['normalized'] is True
assert abs(result['value_counts']['A'] - 0.5) < 0.01
assert abs(result['value_counts']['B'] - 0.5) < 0.01
class TestGetTimeSeriesTool:
"""测试时间序列工具。"""
def test_basic_functionality(self):
"""测试基本功能。"""
tool = GetTimeSeriesTool()
dates = pd.date_range('2020-01-01', periods=10, freq='D')
df = pd.DataFrame({
'date': dates,
'value': range(10)
})
result = tool.execute(df, time_column='date', value_column='value', aggregation='sum')
assert 'time_series' in result
assert result['time_column'] == 'date'
assert result['aggregation'] == 'sum'
assert len(result['time_series']) > 0
def test_count_aggregation(self):
"""测试计数聚合。"""
tool = GetTimeSeriesTool()
dates = pd.date_range('2020-01-01', periods=5, freq='D')
df = pd.DataFrame({'date': dates})
result = tool.execute(df, time_column='date', aggregation='count')
assert 'time_series' in result
assert len(result['time_series']) > 0
def test_output_limit(self):
"""测试输出限制不超过100行"""
tool = GetTimeSeriesTool()
dates = pd.date_range('2020-01-01', periods=200, freq='D')
df = pd.DataFrame({'date': dates})
result = tool.execute(df, time_column='date')
assert len(result['time_series']) <= 100
assert result['total_points'] == 200
assert result['returned_points'] == 100
class TestGetCorrelationTool:
"""测试相关性分析工具。"""
def test_basic_functionality(self):
"""测试基本功能。"""
tool = GetCorrelationTool()
df = pd.DataFrame({
'x': [1, 2, 3, 4, 5],
'y': [2, 4, 6, 8, 10],
'z': [1, 1, 1, 1, 1]
})
result = tool.execute(df)
assert 'correlation_matrix' in result
assert 'x' in result['correlation_matrix']
assert 'y' in result['correlation_matrix']
# x 和 y 完全正相关
assert abs(result['correlation_matrix']['x']['y'] - 1.0) < 0.01
def test_insufficient_numeric_columns(self):
"""测试数值列不足的情况。"""
tool = GetCorrelationTool()
df = pd.DataFrame({
'x': [1, 2, 3],
'text': ['a', 'b', 'c']
})
result = tool.execute(df)
assert 'error' in result
class TestCalculateStatisticsTool:
"""测试统计计算工具。"""
def test_basic_functionality(self):
"""测试基本功能。"""
tool = CalculateStatisticsTool()
df = pd.DataFrame({
'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
})
result = tool.execute(df, column='values')
assert result['mean'] == 5.5
assert result['median'] == 5.5
assert result['min'] == 1
assert result['max'] == 10
assert result['count'] == 10
def test_non_numeric_column(self):
"""测试非数值列。"""
tool = CalculateStatisticsTool()
df = pd.DataFrame({
'text': ['a', 'b', 'c']
})
result = tool.execute(df, column='text')
assert 'error' in result
class TestPerformGroupbyTool:
"""测试分组聚合工具。"""
def test_basic_functionality(self):
"""测试基本功能。"""
tool = PerformGroupbyTool()
df = pd.DataFrame({
'category': ['A', 'B', 'A', 'B', 'A'],
'value': [10, 20, 30, 40, 50]
})
result = tool.execute(df, group_by='category', value_column='value', aggregation='sum')
assert 'groups' in result
assert len(result['groups']) == 2
# 找到 A 组的总和
group_a = next(g for g in result['groups'] if g['group'] == 'A')
assert group_a['value'] == 90 # 10 + 30 + 50
def test_count_aggregation(self):
"""测试计数聚合。"""
tool = PerformGroupbyTool()
df = pd.DataFrame({
'category': ['A', 'B', 'A', 'B', 'A']
})
result = tool.execute(df, group_by='category')
assert len(result['groups']) == 2
group_a = next(g for g in result['groups'] if g['group'] == 'A')
assert group_a['value'] == 3
def test_output_limit(self):
"""测试输出限制不超过100组"""
tool = PerformGroupbyTool()
df = pd.DataFrame({
'category': [f'cat_{i}' for i in range(200)],
'value': range(200)
})
result = tool.execute(df, group_by='category', value_column='value', aggregation='sum')
assert len(result['groups']) <= 100
assert result['total_groups'] == 200
assert result['returned_groups'] == 100
class TestDetectOutliersTool:
"""测试异常值检测工具。"""
def test_iqr_method(self):
"""测试 IQR 方法。"""
tool = DetectOutliersTool()
# 创建包含明显异常值的数据
df = pd.DataFrame({
'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 100]
})
result = tool.execute(df, column='values', method='iqr')
assert result['outlier_count'] > 0
assert 100 in result['outlier_values']
def test_zscore_method(self):
"""测试 Z-score 方法。"""
tool = DetectOutliersTool()
df = pd.DataFrame({
'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 100]
})
result = tool.execute(df, column='values', method='zscore', threshold=2)
assert result['outlier_count'] > 0
assert result['method'] == 'zscore'
class TestCalculateTrendTool:
"""测试趋势计算工具。"""
def test_increasing_trend(self):
"""测试上升趋势。"""
tool = CalculateTrendTool()
dates = pd.date_range('2020-01-01', periods=10, freq='D')
df = pd.DataFrame({
'date': dates,
'value': range(10)
})
result = tool.execute(df, time_column='date', value_column='value')
assert result['trend'] == 'increasing'
assert result['slope'] > 0
assert result['r_squared'] > 0.9 # 完美线性关系
def test_decreasing_trend(self):
"""测试下降趋势。"""
tool = CalculateTrendTool()
dates = pd.date_range('2020-01-01', periods=10, freq='D')
df = pd.DataFrame({
'date': dates,
'value': list(range(10, 0, -1))
})
result = tool.execute(df, time_column='date', value_column='value')
assert result['trend'] == 'decreasing'
assert result['slope'] < 0
class TestToolParameterValidation:
"""测试工具参数验证。"""
def test_missing_required_parameter(self):
"""测试缺少必需参数。"""
tool = GetColumnDistributionTool()
df = pd.DataFrame({'col': [1, 2, 3]})
# 不提供必需的 column 参数
result = tool.execute(df)
# 应该返回错误或引发异常
assert 'error' in result or result is None
def test_invalid_aggregation_method(self):
"""测试无效的聚合方法。"""
tool = PerformGroupbyTool()
df = pd.DataFrame({
'category': ['A', 'B'],
'value': [1, 2]
})
result = tool.execute(df, group_by='category', value_column='value', aggregation='invalid')
assert 'error' in result
class TestToolErrorHandling:
"""测试工具错误处理。"""
def test_empty_dataframe(self):
"""测试空 DataFrame。"""
tool = CalculateStatisticsTool()
df = pd.DataFrame()
result = tool.execute(df, column='nonexistent')
assert 'error' in result
def test_all_null_values(self):
"""测试全部为空值的列。"""
tool = CalculateStatisticsTool()
df = pd.DataFrame({
'values': [None, None, None]
})
result = tool.execute(df, column='values')
# 应该处理空值情况
assert 'error' in result or result['count'] == 0
def test_invalid_date_column(self):
"""测试无效的日期列。"""
tool = GetTimeSeriesTool()
df = pd.DataFrame({
'not_date': ['a', 'b', 'c']
})
result = tool.execute(df, time_column='not_date')
assert 'error' in result
class TestToolRegistry:
"""测试工具注册表。"""
def test_register_and_retrieve(self):
"""测试注册和检索工具。"""
registry = ToolRegistry()
tool = GetColumnDistributionTool()
registry.register(tool)
retrieved = registry.get_tool(tool.name)
assert retrieved.name == tool.name
def test_unregister(self):
"""测试注销工具。"""
registry = ToolRegistry()
tool = GetColumnDistributionTool()
registry.register(tool)
registry.unregister(tool.name)
with pytest.raises(KeyError):
registry.get_tool(tool.name)
def test_list_tools(self):
"""测试列出所有工具。"""
registry = ToolRegistry()
tool1 = GetColumnDistributionTool()
tool2 = GetValueCountsTool()
registry.register(tool1)
registry.register(tool2)
tools = registry.list_tools()
assert len(tools) == 2
assert tool1.name in tools
assert tool2.name in tools
def test_get_applicable_tools(self):
"""测试获取适用的工具。"""
registry = ToolRegistry()
# 注册所有工具
registry.register(GetColumnDistributionTool())
registry.register(CalculateStatisticsTool())
registry.register(GetTimeSeriesTool())
# 创建包含数值和时间列的数据画像
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=2,
columns=[
ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50),
ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100)
],
inferred_type='unknown'
)
applicable = registry.get_applicable_tools(profile)
# 所有工具都应该适用GetColumnDistributionTool 适用于所有数据)
assert len(applicable) > 0
class TestToolManager:
"""测试工具管理器。"""
def test_select_tools_for_datetime_data(self):
"""测试为包含时间字段的数据选择工具。"""
from src.tools.tool_manager import ToolManager
# 创建工具注册表并注册所有工具
registry = ToolRegistry()
registry.register(GetTimeSeriesTool())
registry.register(CalculateTrendTool())
registry.register(GetColumnDistributionTool())
manager = ToolManager(registry)
# 创建包含时间字段的数据画像
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
tools = manager.select_tools(profile)
tool_names = [tool.name for tool in tools]
# 应该包含时间序列工具
assert 'get_time_series' in tool_names
assert 'calculate_trend' in tool_names
def test_select_tools_for_numeric_data(self):
"""测试为包含数值字段的数据选择工具。"""
from src.tools.tool_manager import ToolManager
registry = ToolRegistry()
registry.register(CalculateStatisticsTool())
registry.register(DetectOutliersTool())
registry.register(GetCorrelationTool())
manager = ToolManager(registry)
# 创建包含数值字段的数据画像
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=2,
columns=[
ColumnInfo(name='value1', dtype='numeric', missing_rate=0.0, unique_count=50),
ColumnInfo(name='value2', dtype='numeric', missing_rate=0.0, unique_count=50)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
tools = manager.select_tools(profile)
tool_names = [tool.name for tool in tools]
# 应该包含统计工具
assert 'calculate_statistics' in tool_names
assert 'detect_outliers' in tool_names
assert 'get_correlation' in tool_names
def test_select_tools_for_categorical_data(self):
"""测试为包含分类字段的数据选择工具。"""
from src.tools.tool_manager import ToolManager
registry = ToolRegistry()
registry.register(GetColumnDistributionTool())
registry.register(GetValueCountsTool())
registry.register(PerformGroupbyTool())
manager = ToolManager(registry)
# 创建包含分类字段的数据画像
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='category', dtype='categorical', missing_rate=0.0, unique_count=5)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
tools = manager.select_tools(profile)
tool_names = [tool.name for tool in tools]
# 应该包含分类工具
assert 'get_column_distribution' in tool_names
assert 'get_value_counts' in tool_names
assert 'perform_groupby' in tool_names
def test_no_geo_tools_for_non_geo_data(self):
"""测试不为非地理数据选择地理工具。"""
from src.tools.tool_manager import ToolManager
registry = ToolRegistry()
registry.register(GetColumnDistributionTool())
manager = ToolManager(registry)
# 创建不包含地理字段的数据画像
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
tools = manager.select_tools(profile)
tool_names = [tool.name for tool in tools]
# 不应该包含地理工具
assert 'create_map_visualization' not in tool_names
def test_identify_missing_tools(self):
"""测试识别缺失的工具。"""
from src.tools.tool_manager import ToolManager
# 创建空的工具注册表
empty_registry = ToolRegistry()
manager = ToolManager(empty_registry)
# 创建包含时间字段的数据画像
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
# 尝试选择工具
tools = manager.select_tools(profile)
# 获取缺失的工具
missing = manager.get_missing_tools()
# 应该识别出缺失的时间序列工具
assert len(missing) > 0
assert any(tool in missing for tool in ['get_time_series', 'calculate_trend'])
def test_clear_missing_tools(self):
"""测试清空缺失工具列表。"""
from src.tools.tool_manager import ToolManager
empty_registry = ToolRegistry()
manager = ToolManager(empty_registry)
# 创建数据画像并选择工具(会记录缺失工具)
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=1,
columns=[
ColumnInfo(name='date', dtype='datetime', missing_rate=0.0, unique_count=100)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
manager.select_tools(profile)
assert len(manager.get_missing_tools()) > 0
# 清空缺失工具列表
manager.clear_missing_tools()
assert len(manager.get_missing_tools()) == 0
def test_get_tool_descriptions(self):
"""测试获取工具描述。"""
from src.tools.tool_manager import ToolManager
registry = ToolRegistry()
tool1 = GetColumnDistributionTool()
tool2 = CalculateStatisticsTool()
registry.register(tool1)
registry.register(tool2)
manager = ToolManager(registry)
tools = [tool1, tool2]
descriptions = manager.get_tool_descriptions(tools)
assert len(descriptions) == 2
assert all('name' in desc for desc in descriptions)
assert all('description' in desc for desc in descriptions)
assert all('parameters' in desc for desc in descriptions)
def test_tool_deduplication(self):
"""测试工具去重。"""
from src.tools.tool_manager import ToolManager
registry = ToolRegistry()
# 注册一个工具,它可能被多个类别选中
tool = GetColumnDistributionTool()
registry.register(tool)
manager = ToolManager(registry)
# 创建包含多种类型字段的数据画像
profile = DataProfile(
file_path='test.csv',
row_count=100,
column_count=2,
columns=[
ColumnInfo(name='category', dtype='categorical', missing_rate=0.0, unique_count=5),
ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50)
],
inferred_type='unknown',
key_fields={},
quality_score=100.0,
summary='Test data'
)
tools = manager.select_tools(profile)
tool_names = [tool.name for tool in tools]
# 工具名称应该是唯一的(没有重复)
assert len(tool_names) == len(set(tool_names))